From 6071d362922142b5543c7f61975ad6d59aa318eb Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 28 Jul 2025 19:44:03 +0000 Subject: [PATCH] Merge commit '504b101da33bd1ae2b39e13342c961eb0ddb4458' into develop --- .pre-commit-config.yaml | 2 +- CHANGELOG.md | 3 +- CMakeLists.txt | 2 + Dockerfile | 1 + Jenkinsfile | 63 +- TERMINOLOGY.md | 348 +++- .../grouped_conv2d_fwd_ngchw.cpp | 6 +- .../grouped_conv2d_bwd_data.cpp | 6 +- .../grouped_conv2d_bwd_data_ngchw.cpp | 6 +- .../grouped_conv3d_bwd_data.cpp | 6 +- ..._conv3d_bwd_data_input_fp16_comp_bf8f8.cpp | 6 +- .../elementwise_layernorm2d.cpp | 2 +- client_example/15_reduce/reduce_nhwc_c.cpp | 18 +- ...d_conv_bwd_data_bilinear_residual_fp16.cpp | 6 +- .../grouped_conv_bwd_data_scale_fp16.cpp | 6 +- ...rouped_conv_fwd_bilinear_residual_fp16.cpp | 6 +- .../common.hpp | 32 +- .../grouped_conv_fwd_scale_fp16.cpp | 6 +- .../grouped_conv_fwd_scaleadd_ab.inc | 4 +- client_example/25_wrapper/wrapper_img2col.cpp | 6 +- client_example/CMakeLists.txt | 2 +- cmake/gtest.cmake | 3 + codegen/CMakeLists.txt | 2 +- codegen/include/ck/host/stringutils.hpp | 5 +- ...wd_multiple_abd_operation_xdl_cshuffle.cpp | 11 +- codegen/test/batched_gemm_softmax_gemm.cpp | 12 +- codegen/test/gemm_multiple_d.cpp | 10 +- codegen/test/rtc/include/rtc/tmp_dir.hpp | 2 +- codegen/test/rtc/src/compile_kernel.cpp | 4 +- .../Composable-Kernel-prerequisites.rst | 2 +- example/01_gemm/CMakeLists.txt | 2 + .../gemm_wmma_fp16_pk_i4_v3_b_scale.cpp | 367 ++++ example/01_gemm/gemm_xdl_fp64.cpp | 11 +- example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp | 6 +- example/12_reduce/reduce_blockwise_impl.hpp | 2 +- .../gemm_reduce_xdl_common.hpp | 6 +- .../batched_gemm_reduce_xdl_fp16.cpp | 6 +- .../run_layernorm_example.inc | 4 +- ...rouped_gemm_scale_softmax_gemm_permute.inc | 8 +- .../34_batchnorm/batchnorm_backward_nhwc.cpp | 4 +- .../batchnorm_forward_inferring_nhwc.cpp | 5 +- .../batchnorm_forward_training_nhwc.cpp | 7 +- ...tchnorm_forward_training_nhwc_obsolete.cpp | 7 +- .../sparse_embedding3_forward_layernorm.cpp | 8 +- .../common.hpp | 2 +- example/39_permute/common.hpp | 13 +- .../run_groupnorm_fwd_example.inc | 4 +- ...entwise_scale_permute_amax_2D_fp16_fp8.cpp | 6 +- .../contraction_multi_ABD_xdl_fp16.cpp | 2 +- .../contraction_multi_ABD_xdl_fp8.cpp | 4 +- .../convnd_fwd_convscale_reduce_common.hpp | 8 +- .../run_layernorm4d_fwd_example.inc | 4 +- .../moe_gemm1_xdl_pk_i4.cpp | 2 +- example/CMakeLists.txt | 168 +- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 552 +++--- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 38 +- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 13 +- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 28 +- example/ck_tile/01_fmha/generate.py | 3 - .../02_layernorm2d/layernorm2d_fwd.cpp | 3 +- example/ck_tile/03_gemm/README.md | 2 +- example/ck_tile/03_gemm/gemm_basic.cpp | 2 +- example/ck_tile/03_gemm/gemm_utils.hpp | 24 +- .../03_gemm/gemm_weight_preshuffle.cpp | 211 +-- example/ck_tile/03_gemm/run_gemm_example.inc | 36 +- example/ck_tile/03_gemm/universal_gemm.cpp | 209 +-- .../ck_tile/04_img2col/image_to_column.cpp | 14 +- .../matrix_core_swizzle_kernel.hpp | 14 +- .../10_rmsnorm2d/example_rmsnorm2d_fwd.cpp | 42 +- example/ck_tile/10_rmsnorm2d/generate.py | 257 ++- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 38 +- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp | 2 + .../ck_tile/10_rmsnorm2d/script/perf_test.sh | 103 +- .../ck_tile/10_rmsnorm2d/script/smoke_test.sh | 54 +- .../add_rmsnorm2d_rdquant_fwd.cpp | 4 +- .../example_add_rmsnorm2d_rdquant_fwd.cpp 
| 3 +- .../12_smoothquant/example_smoothquant.cpp | 7 +- .../ck_tile/12_smoothquant/smoothquant.cpp | 5 +- .../ck_tile/13_moe_sorting/moe_sorting.cpp | 105 +- .../13_moe_sorting/moe_sorting_api.cpp | 99 +- .../13_moe_sorting/moe_sorting_api.hpp | 12 +- .../13_moe_sorting/script/smoke_test.sh | 45 + .../14_moe_smoothquant/moe_smoothquant.cpp | 6 +- .../15_fused_moe/instances/fused_moe_api.cpp | 8 +- .../instances/fused_moegemm_api_internal.hpp | 10 +- .../instances/fused_moesorting_api.cpp | 63 +- example/ck_tile/15_fused_moe/main.cpp | 5 +- .../run_batched_gemm_example.inc | 29 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 7 +- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 2 +- .../17_grouped_gemm/grouped_gemm_tileloop.cpp | 7 +- .../run_grouped_gemm_example.inc | 32 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 4 + .../19_gemm_multi_d/gemm_multi_d_fp16.cpp | 9 +- .../19_gemm_multi_d/gemm_multi_d_fp16.hpp | 2 +- .../20_grouped_convolution/CMakeLists.txt | 6 +- .../grouped_convolution_backward_weight.cpp | 218 +++ .../grouped_convolution_forward.cpp | 8 +- .../grouped_convolution_utils.hpp | 27 +- ...grouped_convolution_bwd_weight_example.inc | 187 ++ ...> run_grouped_convolution_fwd_example.inc} | 38 +- example/ck_tile/21_elementwise/CMakeLists.txt | 15 + .../21_elementwise/elementwise_example.cpp | 214 +++ .../elementwise_example_add_4d.cpp | 159 ++ .../elementwise_example_transpose.cpp | 156 ++ .../elementwise_example_unary.cpp | 147 ++ .../batched_transpose_api.cpp | 215 ++- .../batched_transpose_example.cpp | 16 +- .../batched_transpose_example.hpp | 1 + .../35_batched_transpose/script/perf_test.sh | 12 +- .../35_batched_transpose/script/smoke_test.sh | 42 +- example/ck_tile/36_copy/CMakeLists.txt | 4 - example/ck_tile/36_copy/test_copy.cpp | 118 -- example/ck_tile/37_transpose/CMakeLists.txt | 9 - example/ck_tile/37_transpose/README.md | 27 - .../37_transpose/batched_transpose_kernel.hpp | 120 -- .../ck_tile/37_transpose/block_transpose.hpp | 149 -- .../ck_tile/37_transpose/transpose_api.cpp | 59 - .../38_block_scale_gemm/CMakeLists.txt | 13 + example/ck_tile/38_block_scale_gemm/README.md | 35 + .../38_block_scale_gemm/gemm_aquant_basic.cpp | 230 +++ .../38_block_scale_gemm/gemm_utils.hpp | 675 +++++++ .../run_gemm_aquant_example.inc | 259 +++ example/ck_tile/CMakeLists.txt | 4 +- example/ck_tile/remod.py | 2 +- include/ck/host_utility/hip_check_error.hpp | 5 +- include/ck/library/utility/algorithm.hpp | 8 +- include/ck/library/utility/fill.hpp | 7 +- include/ck/library/utility/host_tensor.hpp | 4 +- .../ck/tensor_description/tensor_adaptor.hpp | 24 +- .../tensor_description/tensor_descriptor.hpp | 12 +- .../tensor_space_filling_curve.hpp | 6 +- ...blockwise_gemm_mx_pipeline_xdlops_base.hpp | 2 +- .../blockwise_gemm_pipeline_wmmaops_base.hpp | 76 +- .../blockwise_gemm_pipeline_wmmaops_v1.hpp | 155 +- .../blockwise_gemm_pipeline_wmmaops_v3.hpp | 99 +- .../block/blockwise_gemm_pipeline_xdlops.hpp | 8 +- .../blockwise_gemm_pipeline_xdlops_base.hpp | 4 +- ...gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp | 234 ++- .../block/blockwise_gemm_smfmac_xdlops.hpp | 4 +- .../gpu/block/blockwise_gemm_xdlops.hpp | 12 +- .../blockwise_gemm_xdlops_skip_b_lds.hpp | 2 +- ...roup_tensor_slice_transfer_direct_load.hpp | 6 +- ...nsor_slice_transfer_gather_direct_load.hpp | 12 +- .../gpu/device/device_base.hpp | 12 +- .../gpu/device/device_grouped_gemm.hpp | 12 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 36 +- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 36 +- 
.../device_batched_gemm_e_permute_xdl.hpp | 28 +- .../device_batched_gemm_gemm_xdl_cshuffle.hpp | 37 +- .../impl/device_batched_gemm_multi_d_xdl.hpp | 34 +- .../device_batched_gemm_multiple_d_dl.hpp | 32 +- ...ultiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 56 +- ...atched_gemm_multiple_d_xdl_cshuffle_v3.hpp | 8 +- ...evice_batched_gemm_reduce_xdl_cshuffle.hpp | 38 +- ...emm_softmax_gemm_permute_wmma_cshuffle.hpp | 64 +- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 46 +- ...batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 40 +- .../device_batched_gemm_wmma_cshuffle_v3.hpp | 13 +- .../device/impl/device_batched_gemm_xdl.hpp | 4 +- ...evice_batched_gemm_xdl_fpAintB_b_scale.hpp | 8 +- .../impl/device_cgemm_4gemm_xdl_cshuffle.hpp | 4 +- .../impl/device_column_to_image_impl.hpp | 12 +- ..._contraction_multiple_abd_xdl_cshuffle.hpp | 32 +- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 32 +- .../device/impl/device_contraction_utils.hpp | 10 +- ...fle_bias_activation_add_nhwc_kyxc_nhwk.hpp | 5 +- ...shuffle_bias_activation_nhwc_kyxc_nhwk.hpp | 5 +- ...onv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 5 +- ...evice_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 32 +- .../device/impl/device_gemm_multiple_d_dl.hpp | 28 +- ...gemm_multiple_d_layernorm_xdl_cshuffle.hpp | 85 +- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 40 +- .../device_gemm_multiple_d_xdl_cshuffle.hpp | 30 +- .../impl/device_gemm_wmma_cshuffle_v3.hpp | 231 +-- .../device_gemm_wmma_cshuffle_v3_b_scale.hpp | 302 ++++ .../device_gemm_wmma_cshuffle_v3_common.hpp | 265 +++ .../device_gemm_xdl_waveletmodel_cshuffle.hpp | 25 +- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 14 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 518 ++++-- .../device_grouped_conv_bwd_weight_dl.hpp | 22 +- ...e_grouped_conv_bwd_weight_explicit_xdl.hpp | 30 +- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 29 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 36 +- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 29 +- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 36 +- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 32 +- ...ice_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp | 22 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 66 +- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 707 +++++--- ...fwd_multiple_d_multiple_r_xdl_cshuffle.hpp | 44 +- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 18 +- ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 14 +- .../device_grouped_gemm_multiple_d_dl.hpp | 12 +- ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 20 +- ...gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 20 +- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 18 +- .../device/impl/device_grouped_gemm_xdl.hpp | 12 +- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 20 +- ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 16 +- ...e_grouped_query_attention_forward_wmma.hpp | 28 +- .../gpu/device/impl/device_moe_gemm.hpp | 50 +- .../impl/device_moe_gemm_blockscale.hpp | 116 +- .../impl/device_moe_mx_gemm_bpreshuffle.hpp | 112 +- ...ice_multi_query_attention_forward_wmma.hpp | 28 +- ...tk_contraction_multiple_d_xdl_cshuffle.hpp | 36 +- .../gpu/device/masking_specialization.hpp | 2 +- .../element/binary_element_wise_operation.hpp | 8 +- .../element/unary_element_wise_operation.hpp | 2 +- ...iple_d_welford_first_half_xdl_cshuffle.hpp | 17 +- ...idwise_2d_reduction_threadwise_multi_d.hpp | 5 +- ...iple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp | 34 +- .../gpu/grid/gridwise_elementwise_2d.hpp | 99 +- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 28 +- 
...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 48 +- .../gpu/grid/gridwise_gemm_dl_v1r3.hpp | 16 +- .../gpu/grid/gridwise_gemm_dpp.hpp | 21 +- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 24 +- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 2 +- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 103 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 24 +- ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 70 +- ...se_gemm_multiple_d_xdl_splitk_cshuffle.hpp | 24 +- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 34 +- ...e_gemm_split_k_multiple_d_xdl_cshuffle.hpp | 34 +- ...emm_split_k_multiple_d_xdl_cshuffle_v2.hpp | 24 +- .../gpu/grid/gridwise_gemm_wmma.hpp | 24 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 1570 ++--------------- ...gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp | 541 ++++++ .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 1420 +++++++++++++++ .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 19 +- .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 33 +- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 43 +- .../grid/gridwise_gemm_xdl_cshuffle_v2.hpp | 41 +- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 23 +- ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 33 +- .../gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp | 26 +- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 49 +- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 378 ++-- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 35 +- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 61 +- ...fle_v3_multi_d_blockscale_b_preshuffle.hpp | 51 +- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 32 +- ...se_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 28 +- ...ridwise_gemm_xdl_layernorm_cshuffle_v1.hpp | 38 +- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 54 +- .../gridwise_gemm_xdlops_skip_b_lds_v1.hpp | 24 +- ...ise_gemm_xdlops_splitk_lds_direct_load.hpp | 27 +- .../gpu/grid/gridwise_gemm_xdlops_streamk.hpp | 34 +- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 50 +- .../gpu/grid/gridwise_gemm_xdlops_v2r4.hpp | 22 +- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 27 +- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 26 +- .../gpu/grid/gridwise_gemm_xdlops_v3r2.hpp | 32 +- .../gpu/grid/gridwise_gemm_xdlops_v3r3.hpp | 38 +- .../gpu/grid/gridwise_moe_gemm.hpp | 192 +- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 190 +- .../gpu/grid/gridwise_moe_mx_gemm.hpp | 113 +- .../gpu/grid/gridwise_moe_mx_gemm_bns.hpp | 123 +- .../grid/gridwise_moe_mx_gemm_bpreshuffle.hpp | 125 +- .../gpu/grid/gridwise_permute.hpp | 2 +- .../gpu/grid/gridwise_tensor_rearrange.hpp | 16 +- .../gridwise_normalization_bwd_data.hpp | 2 +- .../threadwise_tensor_slice_transfer.hpp | 12 +- .../threadwise_tensor_slice_transfer_v3r1.hpp | 44 +- ...ise_tensor_slice_transfer_v3r1_dequant.hpp | 18 +- ...wise_tensor_slice_transfer_v3r1_gather.hpp | 12 +- .../threadwise_tensor_slice_transfer_v3r2.hpp | 12 +- .../threadwise_tensor_slice_transfer_v5r1.hpp | 12 +- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 12 +- .../transform_conv_bwd_data_to_gemm_v1.hpp | 127 +- include/ck/utility/amd_ck_fp8.hpp | 10 +- include/ck/utility/amd_xdlops.hpp | 32 +- include/ck/utility/container_helper.hpp | 2 +- include/ck/utility/data_type.hpp | 4 +- include/ck/utility/dynamic_buffer.hpp | 2 +- include/ck/utility/env.hpp | 1 + include/ck/utility/is_detected.hpp | 4 +- include/ck/utility/magic_division.hpp | 12 +- include/ck/utility/scaled_type_convert.hpp | 12 +- include/ck/utility/sequence.hpp | 4 +- include/ck/utility/synchronization.hpp | 2 +- include/ck/utility/type_convert.hpp | 67 +- include/ck/wrapper/tensor.hpp 
| 22 +- include/ck_tile/core.hpp | 1 + .../core/algorithm/coordinate_transform.hpp | 2 +- .../core/algorithm/space_filling_curve.hpp | 6 +- .../core/arch/amd_buffer_addressing.hpp | 108 +- .../arch/amd_buffer_addressing_builtins.hpp | 80 +- .../core/arch/amd_transpose_load_encoding.hpp | 58 +- include/ck_tile/core/arch/arch.hpp | 68 + .../core/container/container_helper.hpp | 2 +- include/ck_tile/core/container/sequence.hpp | 5 +- include/ck_tile/core/container/tuple.hpp | 27 +- include/ck_tile/core/numeric/float8.hpp | 2 +- include/ck_tile/core/numeric/math.hpp | 66 +- include/ck_tile/core/numeric/pk_int4.hpp | 18 + include/ck_tile/core/tensor/buffer_view.hpp | 65 +- .../core/tensor/load_tile_transpose.hpp | 324 ++-- include/ck_tile/core/tensor/sweep_tile.hpp | 2 +- .../ck_tile/core/tensor/tensor_adaptor.hpp | 32 +- .../ck_tile/core/tensor/tile_distribution.hpp | 10 +- .../ck_tile/core/tensor/tile_elementwise.hpp | 5 +- .../core/tensor/tile_window_linear.hpp | 15 +- include/ck_tile/core/utility/debug.hpp | 156 ++ include/ck_tile/core/utility/type_traits.hpp | 4 +- .../core/utility/unary_element_function.hpp | 6 +- include/ck_tile/host.hpp | 2 + include/ck_tile/host/concat.hpp | 19 +- include/ck_tile/host/fill.hpp | 80 +- include/ck_tile/host/host_tensor.hpp | 2 +- include/ck_tile/host/joinable_thread.hpp | 2 +- .../host/reference/reference_elementwise.hpp | 2 +- .../ck_tile/host/reference/reference_gemm.hpp | 104 ++ .../reference_grouped_conv_bwd_weight.hpp | 167 ++ .../host/reference/reference_moe_sorting.hpp | 2 +- .../host/reference/reference_transpose.hpp | 33 + include/ck_tile/ops/batched_transpose.hpp | 4 + .../kernel/batched_transpose_kernel.hpp | 4 +- .../batched_transpose_common_policy.hpp | 33 + .../batched_transpose_lds_pipeline.hpp | 67 + .../pipeline/batched_transpose_lds_policy.hpp | 64 +- .../batched_transpose_lds_problem.hpp | 73 + .../pipeline/batched_transpose_pipeline.hpp | 15 +- .../pipeline/batched_transpose_policy.hpp | 34 +- .../pipeline/batched_transpose_problem.hpp | 31 +- include/ck_tile/ops/elementwise.hpp | 5 + .../binary_elementwise_operation.hpp | 94 + .../elementwise/kernel/elementwise_kernel.hpp | 123 ++ .../elementwise_pipeline_default_policy.hpp | 29 + .../pipeline/elementwise_pipeline_problem.hpp | 26 + .../pipeline/elementwise_shape.hpp | 29 + .../unary_element_wise_operation.hpp | 92 +- .../ops/epilogue/cshuffle_epilogue.hpp | 12 +- .../flatmm_32x512x128_1x4x1_16x16x32.hpp | 4 +- .../flatmm_sn_32x128x512_1x4x1_16x16x32.hpp | 4 +- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 2 +- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 12 +- .../ops/flatmm/pipeline/tile_flatmm_shape.hpp | 3 + include/ck_tile/ops/fmha.hpp | 6 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 10 +- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 3 - ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 11 + ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 2 +- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 4 +- .../fused_moe/kernel/fused_moegemm_kernel.hpp | 2 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 397 ++++- .../fused_moe/kernel/moe_sorting_problem.hpp | 8 + .../fused_moegemm_pipeline_flatmm_ex.hpp | 50 +- include/ck_tile/ops/gemm.hpp | 8 +- ...emm_asmem_bsmem_creg_v1_default_policy.hpp | 46 +- .../block/block_universal_gemm_as_bs_cr.hpp | 154 +- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 166 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 1015 +---------- .../ops/gemm/kernel/gemm_multi_d_kernel.hpp | 185 ++ .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 17 +- 
.../ops/gemm/kernel/grouped_gemm_kernel.hpp | 165 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 1169 ++++++++++++ .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 63 +- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 26 +- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 137 +- ...peline_ag_bg_cr_comp_v4_default_policy.hpp | 28 +- ...peline_ag_bg_cr_comp_v5_default_policy.hpp | 12 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 79 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 6 +- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 12 +- .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 4 +- .../gemm/pipeline/gemm_pipeline_problem.hpp | 3 + ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 154 +- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 4 +- ...wp_pipeline_agmem_bgmem_creg_v1_policy.hpp | 12 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 84 +- .../gemm/warp/warp_gemm_attribute_mfma.hpp | 280 +-- .../warp/warp_gemm_attribute_mfma_impl.hpp | 42 +- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 61 +- .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 7 +- include/ck_tile/ops/gemm_group_quant.hpp | 16 + .../block_universal_gemm_as_aquant_bs_cr.hpp | 482 +++++ .../kernel/gemm_aquant_kernel.hpp | 679 +++++++ .../gemm_aquant_pipeline_ag_bg_cr_base.hpp | 53 + .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 93 + .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 475 +++++ .../pipeline/gemm_aquant_pipeline_problem.hpp | 121 ++ .../pipeline/gemm_group_quant_utils.hpp | 95 + .../pipeline/tile_gemm_aquant_traits.hpp | 34 + include/ck_tile/ops/grouped_convolution.hpp | 2 + ...ped_convolution_backward_weight_kernel.hpp | 862 +++++++++ .../grouped_convolution_forward_kernel.hpp | 123 +- .../utils/grouped_convolution_utils.hpp | 27 +- .../transform_conv_bwd_weight_to_gemm.hpp | 659 +++++++ .../ck_tile/ops/reduce/block/block_reduce.hpp | 2 +- .../ops/reduce/block/block_reduce2d.hpp | 133 ++ include/ck_tile/ops/rmsnorm2d.hpp | 1 + .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 17 +- .../rmsnorm2d_fwd_pipeline_default_policy.hpp | 9 + ...rm2d_fwd_pipeline_model_sensitive_pass.hpp | 228 +++ .../rmsnorm2d_fwd_pipeline_one_pass.hpp | 5 +- .../pipeline/rmsnorm2d_fwd_traits.hpp | 31 +- include/ck_tile/ref/naive_attention.hpp | 24 +- include/ck_tile/remod.py | 16 +- .../cpu/reference_moe_gemm.hpp | 2 +- .../cpu/reference_moe_gemm1_blockscale.hpp | 2 +- .../gpu/reference_gemm.hpp | 20 +- .../device_column_to_image_instance.hpp | 36 +- .../device_image_to_column_instance.hpp | 36 +- .../gpu/gemm_b_scale.hpp | 24 +- ...ice_grouped_conv_bwd_data_xdl_instance.hpp | 30 + ...p_gemm_xdl_universal_km_kn_mn_instance.hpp | 9 +- ...ce_grouped_conv_bwd_weight_dl_instance.hpp | 27 +- ..._grouped_conv_bwd_weight_wmma_instance.hpp | 18 +- .../gpu/CMakeLists.txt | 95 +- ...al_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp | 9 +- ...al_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp | 9 +- ...al_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp | 9 +- ...al_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp | 9 +- ...ersal_f16_f16_f16_gkm_gkn_gmn_instance.cpp | 9 +- ...ersal_f16_f16_f16_gkm_gnk_gmn_instance.cpp | 9 +- ...ersal_f16_f16_f16_gmk_gkn_gmn_instance.cpp | 9 +- ...ersal_f16_f16_f16_gmk_gnk_gmn_instance.cpp | 9 +- ...dl_int8_int8_int8_gkm_gnk_gmn_instance.cpp | 9 +- ..._data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp | 9 +- ..._shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp | 9 +- ..._shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp | 9 +- ..._shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp | 9 +- ...l_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp | 9 +- ...l_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp | 9 +- 
...l_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp | 9 +- ...l_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp | 18 +- .../km_kn_mn_default_pipeline_v1_instance.cpp | 9 +- .../km_kn_mn_default_pipeline_v2_instance.cpp | 7 +- ...kn_mn_default_pipeline_v2_opt_instance.cpp | 7 +- ...m_kn_mn_interwave_pipeline_v1_instance.cpp | 7 +- .../km_nk_mn_default_pipeline_v1_instance.cpp | 9 +- .../km_nk_mn_default_pipeline_v2_instance.cpp | 7 +- ...nk_mn_default_pipeline_v2_opt_instance.cpp | 7 +- ...m_nk_mn_interwave_pipeline_v1_instance.cpp | 7 +- .../mk_kn_mn_default_pipeline_v1_instance.cpp | 9 +- .../mk_kn_mn_default_pipeline_v2_instance.cpp | 7 +- ...kn_mn_default_pipeline_v2_opt_instance.cpp | 7 +- ...k_kn_mn_interwave_pipeline_v1_instance.cpp | 7 +- ...gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp | 9 +- ...gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp | 9 +- ...gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp | 9 +- ...gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp | 9 +- ...gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp | 9 +- ...gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp | 9 +- ...gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp | 9 +- ...gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp | 9 +- ...le_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp | 9 +- .../gpu/gemm_b_scale/CMakeLists.txt | 6 +- ..._gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp | 72 + ...4_f16_mk_nk_mn_mem_v2_default_instance.cpp | 31 + ...e_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp | 4 +- ...evice_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp | 17 +- ...wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp | 9 +- ...wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp | 9 +- ...wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 9 +- ...wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp | 9 +- ...mm_wmma_universal_f16_f16_f16_km_kn_mn.hpp | 9 +- ...mm_wmma_universal_f16_f16_f16_km_nk_mn.hpp | 9 +- ...mm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp | 9 +- ...mm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp | 9 +- ...emm_wmma_universal_f16_f8_f16_km_kn_mn.hpp | 9 +- ...emm_wmma_universal_f16_f8_f16_km_nk_mn.hpp | 9 +- ...emm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp | 9 +- ...emm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp | 9 +- ...emm_wmma_universal_f8_f16_f16_km_kn_mn.hpp | 9 +- ...emm_wmma_universal_f8_f16_f16_km_nk_mn.hpp | 9 +- ...emm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp | 9 +- ...emm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp | 9 +- ..._xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp | 9 +- ..._xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp | 9 +- ..._xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 9 +- ..._xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp | 9 +- ...mm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp | 9 +- ...emm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp | 9 +- ...emm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp | 9 +- ...gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp | 9 +- ...gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp | 9 +- ...gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp | 9 +- ...gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp | 9 +- ...gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp | 9 +- ...gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp | 84 + ...f8_bf16_mk_nk_mn_comp_default_instance.cpp | 8 +- ...8_bf16_mk_nk_mn_comp_kpadding_instance.cpp | 8 +- ..._bf16_mk_nk_mn_mem_v1_default_instance.cpp | 8 + ...bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 8 + ..._bf16_mk_nk_mn_mem_v2_default_instance.cpp | 8 + ...bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 8 + ...versal_streamk_bf16_bf16_bf16_km_kn_mn.hpp | 9 +- ...versal_streamk_bf16_bf16_bf16_km_nk_mn.hpp | 9 +- ...versal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp | 9 +- ...versal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp | 9 +- 
...universal_streamk_f16_f16_f16_mk_kn_mn.hpp | 9 +- ...universal_streamk_f16_f16_f16_mk_nk_mn.hpp | 9 +- ..._universal_streamk_f16_f8_f16_mk_kn_mn.hpp | 7 +- ..._universal_streamk_f16_f8_f16_mk_nk_mn.hpp | 7 +- ..._universal_streamk_f8_f16_f16_mk_kn_mn.hpp | 7 +- ..._universal_streamk_f8_f16_f16_mk_nk_mn.hpp | 7 +- ...ata_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp | 8 + ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp | 8 + ...le_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp | 9 +- library/src/utility/convolution_parameter.cpp | 5 +- .../profiler/profile_conv_bwd_data_impl.hpp | 6 +- .../profiler/profile_conv_fwd_impl.hpp | 6 +- .../profile_conv_tensor_rearrange_impl.hpp | 5 +- .../profiler/profile_gemm_b_scale_impl.hpp | 4 +- .../include/profiler/profile_gemm_impl.hpp | 4 + .../profile_grouped_conv_bwd_data_impl.hpp | 7 +- .../profile_grouped_conv_bwd_weight_impl.hpp | 19 +- ...ofile_grouped_conv_fwd_bias_clamp_impl.hpp | 10 +- .../profile_grouped_conv_fwd_impl.hpp | 6 +- ...ile_grouped_conv_fwd_outelementop_impl.hpp | 6 +- .../include/profiler/profile_softmax_impl.hpp | 23 +- profiler/src/CMakeLists.txt | 10 +- profiler/src/profile_batched_gemm_b_scale.cpp | 3 +- profiler/src/profile_contraction_bilinear.cpp | 3 +- profiler/src/profile_contraction_scale.cpp | 3 +- profiler/src/profile_gemm_b_scale.cpp | 3 +- .../profile_grouped_conv_fwd_bias_clamp.cpp | 191 ++ .../src/profile_grouped_conv_fwd_clamp.cpp | 194 ++ script/clang-format-overwrite.sh | 4 +- script/cmake-ck-dev.sh | 26 +- script/cmake-ck-release.sh | 7 +- test/CMakeLists.txt | 4 + test/block_swizzle_test/rebuild.sh | 2 +- test/ck_tile/CMakeLists.txt | 15 + .../add_rmsnorm2d_rdquant/CMakeLists.txt | 26 + .../add_rmsnorm2d_rdquant_fwd.hpp | 151 ++ .../add_rmsnorm2d_rdquant_fwd.inc | 370 ++++ .../add_rmsnorm2d_rdquant_fwd_bf16.cpp | 6 + .../add_rmsnorm2d_rdquant_fwd_fp16.cpp | 6 + .../add_rmsnorm2d_rdquant_fwd_api.cpp | 227 +++ ...norm2d_rdquant_fwd_bf16_n1024_instance.cpp | 26 + ...norm2d_rdquant_fwd_bf16_n1536_instance.cpp | 17 + ...norm2d_rdquant_fwd_bf16_n2048_instance.cpp | 18 + ...snorm2d_rdquant_fwd_bf16_n256_instance.cpp | 15 + ...norm2d_rdquant_fwd_bf16_n3072_instance.cpp | 17 + ...norm2d_rdquant_fwd_bf16_n4096_instance.cpp | 17 + ...snorm2d_rdquant_fwd_bf16_n512_instance.cpp | 17 + ...m2d_rdquant_fwd_bf16_n64_n128_instance.cpp | 15 + ...snorm2d_rdquant_fwd_bf16_n768_instance.cpp | 15 + ...norm2d_rdquant_fwd_bf16_n8192_instance.cpp | 42 + ...m2d_rdquant_fwd_bf16_n8192_tp_instance.cpp | 17 + ...norm2d_rdquant_fwd_fp16_n1024_instance.cpp | 26 + ...norm2d_rdquant_fwd_fp16_n1536_instance.cpp | 17 + ...norm2d_rdquant_fwd_fp16_n2048_instance.cpp | 18 + ...snorm2d_rdquant_fwd_fp16_n256_instance.cpp | 15 + ...norm2d_rdquant_fwd_fp16_n3072_instance.cpp | 17 + ...norm2d_rdquant_fwd_fp16_n4096_instance.cpp | 17 + ...snorm2d_rdquant_fwd_fp16_n512_instance.cpp | 17 + ...m2d_rdquant_fwd_fp16_n64_n128_instance.cpp | 15 + ...snorm2d_rdquant_fwd_fp16_n768_instance.cpp | 15 + ...norm2d_rdquant_fwd_fp16_n8192_instance.cpp | 41 + ...m2d_rdquant_fwd_fp16_n8192_tp_instance.cpp | 17 + ..._rmsnorm2d_rdquant_fwd_instance_common.hpp | 70 + .../batched_gemm/test_batched_gemm_util.hpp | 29 +- test/ck_tile/batched_transpose/CMakeLists.txt | 33 + .../batched_transpose/batched_transpose.hpp | 8 +- .../batched_transpose/batched_transpose.inc | 166 +- .../batched_transpose_api.cpp | 109 ++ .../batched_transpose_bf16.cpp | 10 + .../batched_transpose_fp16.cpp | 10 + .../batched_transpose_fp8.cpp | 10 + test/ck_tile/container/CMakeLists.txt | 6 + 
test/ck_tile/container/test_tuple_apply.cpp | 223 +++ test/ck_tile/data_type/test_pk_int4.cpp | 8 +- test/ck_tile/elementwise/CMakeLists.txt | 6 + .../elementwise/test_elementwise_1d.cpp | 210 +++ test/ck_tile/gemm/CMakeLists.txt | 19 + .../gemm/test_gemm_pipeline_basic_bf16.cpp | 5 + .../gemm/test_gemm_pipeline_basic_bf8.cpp | 5 + .../gemm/test_gemm_pipeline_basic_fp16.cpp | 5 + .../gemm/test_gemm_pipeline_basic_fp8.cpp | 5 + .../test_gemm_pipeline_basic_run_test.inc | 313 ++++ .../test_gemm_pipeline_smoke_run_test.inc | 456 +++++ .../gemm/test_gemm_pipeline_smoke_util.hpp | 414 +++++ .../test_gemm_pipeline_universal_bf16.cpp | 16 + .../gemm/test_gemm_pipeline_universal_bf8.cpp | 16 + .../test_gemm_pipeline_universal_fp16.cpp | 16 + .../gemm/test_gemm_pipeline_universal_fp8.cpp | 16 + .../test_gemm_pipeline_universal_run_test.inc | 394 +++++ test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 41 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 19 + .../test_gemm_aquant_basic_bf8.cpp | 6 + .../test_gemm_aquant_basic_fp8.cpp | 6 + .../test_gemm_aquant_basic_i4bf8.cpp | 6 + .../test_gemm_aquant_basic_i4f32bf8.cpp | 6 + .../test_gemm_aquant_basic_i4f32fp8.cpp | 6 + .../test_gemm_aquant_basic_i4fp8.cpp | 6 + .../test_gemm_aquant_utils.hpp | 681 +++++++ .../test_run_gemm_aquant_example.inc | 580 ++++++ .../gemm_multi_d/test_gemm_multi_d_util.hpp | 30 +- .../test_gemm_pipeline_kernel_types.hpp | 20 +- .../test_gemm_pipeline_ut_cases.inc | 119 +- .../test_gemm_pipeline_util.hpp | 37 +- .../grouped_gemm/test_grouped_gemm_util.hpp | 57 +- test/ck_tile/layernorm2d/CMakeLists.txt | 53 + test/ck_tile/layernorm2d/generate.py | 730 ++++++++ test/ck_tile/layernorm2d/layernorm2d_fwd.hpp | 70 + test/ck_tile/layernorm2d/layernorm2d_fwd.inc | 565 ++++++ .../layernorm2d/layernorm2d_fwd_bf16.cpp | 6 + .../layernorm2d/layernorm2d_fwd_fp16.cpp | 6 + test/ck_tile/memory_copy/CMakeLists.txt | 3 + .../ck_tile/memory_copy}/README.md | 0 test/ck_tile/memory_copy/test_copy.cpp | 193 ++ .../ck_tile/memory_copy}/test_copy.hpp | 116 +- test/ck_tile/moe_smoothquant/CMakeLists.txt | 32 + .../moe_smoothquant_bf16_n1024_instance.cpp | 27 + .../moe_smoothquant_bf16_n1536_instance.cpp | 18 + .../moe_smoothquant_bf16_n2048_instance.cpp | 19 + .../moe_smoothquant_bf16_n256_instance.cpp | 16 + .../moe_smoothquant_bf16_n3072_instance.cpp | 18 + .../moe_smoothquant_bf16_n4096_instance.cpp | 18 + ...moe_smoothquant_bf16_n4096_tp_instance.cpp | 18 + .../moe_smoothquant_bf16_n512_instance.cpp | 18 + ...moe_smoothquant_bf16_n64_n128_instance.cpp | 16 + .../moe_smoothquant_bf16_n768_instance.cpp | 16 + .../moe_smoothquant_fp16_n1024_instance.cpp | 27 + .../moe_smoothquant_fp16_n1536_instance.cpp | 18 + .../moe_smoothquant_fp16_n2048_instance.cpp | 18 + .../moe_smoothquant_fp16_n256_instance.cpp | 16 + .../moe_smoothquant_fp16_n3072_instance.cpp | 18 + .../moe_smoothquant_fp16_n4096_instance.cpp | 18 + ...moe_smoothquant_fp16_n4096_tp_instance.cpp | 18 + .../moe_smoothquant_fp16_n512_instance.cpp | 18 + ...moe_smoothquant_fp16_n64_n128_instance.cpp | 16 + .../moe_smoothquant_fp16_n768_instance.cpp | 16 + .../instances/moe_smoothquant_fwd_api.cpp | 155 ++ .../moe_smoothquant_instance_common.hpp | 65 + .../moe_smoothquant/moe_smoothquant.hpp | 104 ++ .../moe_smoothquant/moe_smoothquant.inc | 317 ++++ .../moe_smoothquant_bf16_fp8.cpp | 11 + .../moe_smoothquant_bf16_int8.cpp | 11 + .../moe_smoothquant_fp16_fp8.cpp | 11 + .../moe_smoothquant_fp16_int8.cpp | 11 + test/ck_tile/moe_sorting/CMakeLists.txt | 15 + 
test/ck_tile/moe_sorting/moe_sorting_api.cpp | 444 +++++ test/ck_tile/moe_sorting/moe_sorting_api.hpp | 33 + test/ck_tile/moe_sorting/moe_sorting_fp32.cpp | 544 ++++++ test/ck_tile/permute/CMakeLists.txt | 33 + .../alternative_impl/matrix_core_swizzle.cpp | 101 ++ .../alternative_impl/matrix_core_swizzle.hpp | 20 + .../matrix_core_swizzle_kernel.hpp | 413 +++++ test/ck_tile/permute/permute.hpp | 19 + test/ck_tile/permute/permute_fp16.cpp | 29 + test/ck_tile/permute/permute_fp32.cpp | 29 + test/ck_tile/permute/permute_fp8.cpp | 29 + test/ck_tile/permute/permute_utils.inc | 490 +++++ test/ck_tile/rmsnorm2d/CMakeLists.txt | 54 + test/ck_tile/rmsnorm2d/generate.py | 715 ++++++++ test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.hpp | 69 + test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc | 618 +++++++ test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_bf16.cpp | 5 + test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_fp16.cpp | 5 + test/ck_tile/smoothquant/CMakeLists.txt | 28 + .../smoothquant_bf16_n1024_instance.cpp | 21 + .../smoothquant_bf16_n1536_instance.cpp | 12 + .../smoothquant_bf16_n2048_instance.cpp | 13 + .../smoothquant_bf16_n256_instance.cpp | 11 + .../smoothquant_bf16_n3072_instance.cpp | 13 + .../smoothquant_bf16_n4096_instance.cpp | 13 + .../smoothquant_bf16_n4096_tp_instance.cpp | 13 + .../smoothquant_bf16_n512_instance.cpp | 12 + .../smoothquant_bf16_n64_n128_instance.cpp | 11 + .../smoothquant_bf16_n768_instance.cpp | 11 + .../smoothquant_fp16_n1024_instance.cpp | 21 + .../smoothquant_fp16_n1536_instance.cpp | 12 + .../smoothquant_fp16_n2048_instance.cpp | 13 + .../smoothquant_fp16_n256_instance.cpp | 11 + .../smoothquant_fp16_n3072_instance.cpp | 13 + .../smoothquant_fp16_n4096_instance.cpp | 13 + .../smoothquant_fp16_n4096_tp_instance.cpp | 13 + .../smoothquant_fp16_n512_instance.cpp | 12 + .../smoothquant_fp16_n64_n128_instance.cpp | 11 + .../smoothquant_fp16_n768_instance.cpp | 11 + .../instances/smoothquant_fwd_api.cpp | 143 ++ .../instances/smoothquant_instance_common.hpp | 61 + test/ck_tile/smoothquant/smoothquant.hpp | 114 ++ test/ck_tile/smoothquant/smoothquant.inc | 273 +++ test/ck_tile/smoothquant/smoothquant_bf16.cpp | 11 + test/ck_tile/smoothquant/smoothquant_fp16.cpp | 11 + test/ck_tile/topk_softmax/CMakeLists.txt | 19 + .../topk_softmax/test_topk_softmax.hpp | 280 +++ .../topk_softmax/test_topk_softmax_api.cpp | 96 + .../topk_softmax/test_topk_softmax_api.hpp | 21 + .../topk_softmax/test_topk_softmax_bf16.cpp | 6 + .../topk_softmax/test_topk_softmax_fp16.cpp | 6 + test/data_type/test_bhalf.cpp | 46 + test/data_type/test_mx_fp4.cpp | 4 +- test/data_type/test_pk_i4.cpp | 8 +- test/gemm_b_scale/CMakeLists.txt | 9 + .../test_gemm_b_scale_ut_cases.inc | 43 + test/gemm_b_scale/test_gemm_b_scale_util.hpp | 97 + test/gemm_b_scale/test_gemm_b_scale_wmma.cpp | 45 + test/gemm_b_scale/test_gemm_b_scale_xdl.cpp | 45 + test/mx_mfma_op/mx_mfma_op.cpp | 180 +- test/pool/test_max_pool2d_fwd.cpp | 4 +- .../reference_conv_fwd/reference_conv_fwd.cpp | 12 +- test/scatter_gather/scatter_gather.cpp | 4 +- tile_engine/ops/gemm/CMakeLists.txt | 56 +- tile_engine/ops/gemm/README.md | 26 +- tile_engine/ops/gemm/benchmark_gemm.hpp | 6 +- tile_engine/ops/gemm/configs/benchmark.json | 15 - .../ops/gemm/configs/custom_ci_config.json | 82 + .../ops/gemm/configs/default_config.json | 15 - .../gemm/configs/user_provided_config.json | 15 - tile_engine/ops/gemm/gemm_instance_builder.py | 60 +- tile_engine/ops/gemm/gemm_profiler.hpp | 14 +- tile_engine/ops/gemm/json_config.py | 48 +- 710 files changed, 34610 insertions(+), 9666 deletions(-) 
create mode 100644 example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp create mode 100644 example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp create mode 100644 example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc rename example/ck_tile/20_grouped_convolution/{run_grouped_convolution_example.inc => run_grouped_convolution_fwd_example.inc} (81%) create mode 100644 example/ck_tile/21_elementwise/CMakeLists.txt create mode 100644 example/ck_tile/21_elementwise/elementwise_example.cpp create mode 100644 example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp create mode 100644 example/ck_tile/21_elementwise/elementwise_example_transpose.cpp create mode 100644 example/ck_tile/21_elementwise/elementwise_example_unary.cpp delete mode 100644 example/ck_tile/36_copy/CMakeLists.txt delete mode 100644 example/ck_tile/36_copy/test_copy.cpp delete mode 100644 example/ck_tile/37_transpose/CMakeLists.txt delete mode 100644 example/ck_tile/37_transpose/README.md delete mode 100644 example/ck_tile/37_transpose/batched_transpose_kernel.hpp delete mode 100644 example/ck_tile/37_transpose/block_transpose.hpp delete mode 100644 example/ck_tile/37_transpose/transpose_api.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/CMakeLists.txt create mode 100644 example/ck_tile/38_block_scale_gemm/README.md create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_utils.hpp create mode 100644 example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp create mode 100644 include/ck_tile/core/utility/debug.hpp create mode 100644 include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp create mode 100644 include/ck_tile/host/reference/reference_transpose.hpp create mode 100644 include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp create mode 100644 include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp rename example/ck_tile/37_transpose/transpose_policy.hpp => include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp (63%) create mode 100644 include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp create mode 100644 include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp create mode 100644 include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp create mode 100644 include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp create mode 100644 include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp create mode 100644 include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp create mode 100644 
include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp create mode 100644 include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp create mode 100644 include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp create mode 100644 profiler/src/profile_grouped_conv_fwd_clamp.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_bf16.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_fp16.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp create mode 100644 
test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp create mode 100644 test/ck_tile/batched_transpose/CMakeLists.txt rename example/ck_tile/37_transpose/transpose_example.hpp => test/ck_tile/batched_transpose/batched_transpose.hpp (68%) rename example/ck_tile/37_transpose/transpose_example.cpp => test/ck_tile/batched_transpose/batched_transpose.inc (59%) create mode 100644 test/ck_tile/batched_transpose/batched_transpose_api.cpp create mode 100644 test/ck_tile/batched_transpose/batched_transpose_bf16.cpp create mode 100644 test/ck_tile/batched_transpose/batched_transpose_fp16.cpp create mode 100644 test/ck_tile/batched_transpose/batched_transpose_fp8.cpp create mode 100644 test/ck_tile/container/CMakeLists.txt create mode 100644 test/ck_tile/container/test_tuple_apply.cpp create mode 100644 test/ck_tile/elementwise/CMakeLists.txt create mode 100644 test/ck_tile/elementwise/test_elementwise_1d.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc create mode 100644 test/ck_tile/gemm_block_scale/CMakeLists.txt create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp create mode 100644 test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc create mode 100644 test/ck_tile/layernorm2d/CMakeLists.txt create mode 100644 
test/ck_tile/layernorm2d/generate.py create mode 100644 test/ck_tile/layernorm2d/layernorm2d_fwd.hpp create mode 100644 test/ck_tile/layernorm2d/layernorm2d_fwd.inc create mode 100644 test/ck_tile/layernorm2d/layernorm2d_fwd_bf16.cpp create mode 100644 test/ck_tile/layernorm2d/layernorm2d_fwd_fp16.cpp create mode 100644 test/ck_tile/memory_copy/CMakeLists.txt rename {example/ck_tile/36_copy => test/ck_tile/memory_copy}/README.md (100%) create mode 100644 test/ck_tile/memory_copy/test_copy.cpp rename {example/ck_tile/36_copy => test/ck_tile/memory_copy}/test_copy.hpp (56%) create mode 100644 test/ck_tile/moe_smoothquant/CMakeLists.txt create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_bf16_n1536_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_bf16_n2048_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_bf16_n256_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_bf16_n3072_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_bf16_n4096_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_bf16_n4096_tp_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_bf16_n512_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_bf16_n64_n128_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_bf16_n768_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fp16_n1536_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fp16_n2048_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fp16_n256_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fp16_n3072_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fp16_n4096_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fp16_n512_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fp16_n768_instance.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp create mode 100644 test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp create mode 100644 test/ck_tile/moe_smoothquant/moe_smoothquant.hpp create mode 100644 test/ck_tile/moe_smoothquant/moe_smoothquant.inc create mode 100644 test/ck_tile/moe_smoothquant/moe_smoothquant_bf16_fp8.cpp create mode 100644 test/ck_tile/moe_smoothquant/moe_smoothquant_bf16_int8.cpp create mode 100644 test/ck_tile/moe_smoothquant/moe_smoothquant_fp16_fp8.cpp create mode 100644 test/ck_tile/moe_smoothquant/moe_smoothquant_fp16_int8.cpp create mode 100644 test/ck_tile/moe_sorting/CMakeLists.txt create mode 100644 test/ck_tile/moe_sorting/moe_sorting_api.cpp create mode 100644 test/ck_tile/moe_sorting/moe_sorting_api.hpp create mode 100644 test/ck_tile/moe_sorting/moe_sorting_fp32.cpp create mode 100644 test/ck_tile/permute/CMakeLists.txt create mode 
100644 test/ck_tile/permute/alternative_impl/matrix_core_swizzle.cpp create mode 100644 test/ck_tile/permute/alternative_impl/matrix_core_swizzle.hpp create mode 100644 test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp create mode 100644 test/ck_tile/permute/permute.hpp create mode 100644 test/ck_tile/permute/permute_fp16.cpp create mode 100644 test/ck_tile/permute/permute_fp32.cpp create mode 100644 test/ck_tile/permute/permute_fp8.cpp create mode 100644 test/ck_tile/permute/permute_utils.inc create mode 100644 test/ck_tile/rmsnorm2d/CMakeLists.txt create mode 100644 test/ck_tile/rmsnorm2d/generate.py create mode 100644 test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.hpp create mode 100644 test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc create mode 100644 test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_bf16.cpp create mode 100644 test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_fp16.cpp create mode 100644 test/ck_tile/smoothquant/CMakeLists.txt create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_bf16_n1024_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_bf16_n1536_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_bf16_n2048_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_bf16_n256_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_bf16_n3072_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_bf16_n4096_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_bf16_n512_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_bf16_n768_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fp16_n1024_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fp16_n1536_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fp16_n2048_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fp16_n256_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fp16_n3072_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fp16_n4096_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fp16_n512_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fp16_n768_instance.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_fwd_api.cpp create mode 100644 test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp create mode 100644 test/ck_tile/smoothquant/smoothquant.hpp create mode 100644 test/ck_tile/smoothquant/smoothquant.inc create mode 100644 test/ck_tile/smoothquant/smoothquant_bf16.cpp create mode 100644 test/ck_tile/smoothquant/smoothquant_fp16.cpp create mode 100644 test/ck_tile/topk_softmax/CMakeLists.txt create mode 100644 test/ck_tile/topk_softmax/test_topk_softmax.hpp create mode 100644 test/ck_tile/topk_softmax/test_topk_softmax_api.cpp create mode 100644 test/ck_tile/topk_softmax/test_topk_softmax_api.hpp create mode 100644 test/ck_tile/topk_softmax/test_topk_softmax_bf16.cpp create mode 100644 
test/ck_tile/topk_softmax/test_topk_softmax_fp16.cpp create mode 100644 test/gemm_b_scale/CMakeLists.txt create mode 100644 test/gemm_b_scale/test_gemm_b_scale_ut_cases.inc create mode 100644 test/gemm_b_scale/test_gemm_b_scale_util.hpp create mode 100644 test/gemm_b_scale/test_gemm_b_scale_wmma.cpp create mode 100644 test/gemm_b_scale/test_gemm_b_scale_xdl.cpp create mode 100644 tile_engine/ops/gemm/configs/custom_ci_config.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e4e85651f6..664c5219e2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: hooks: - id: clang-format name: clang-format - entry: clang-format-12 -i --style=file + entry: clang-format-18 -i --style=file language: system types_or: [c++, inc] - id: copyright-year-checker diff --git a/CHANGELOG.md b/CHANGELOG.md index 17f9455feb..fa3ba71143 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). -## Composable Kernel 1.1.0 for ROCm 6.5.0 +## Composable Kernel 1.1.0 for ROCm 7.0.0 ### Added @@ -23,6 +23,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added Ping-pong scheduler support for GEMM operation along the K dimension. * Added rotating buffer feature for CK_Tile GEMM. * Added int8 support for CK_TILE GEMM. +* Added support for elementwise kernels. ### Optimized diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e032a30cf..da5a86523e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -236,6 +236,8 @@ endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_NATIVE_MX_SUPPORT) set(CK_USE_NATIVE_MX_SUPPORT "ON") + add_definitions(-DCK_GFX950_SUPPORT) + set(CK_GFX950_SUPPORT "ON") endif() option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) diff --git a/Dockerfile b/Dockerfile index 0219f99238..6f5cd0115d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -62,6 +62,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libzstd-dev \ openssh-server \ clang-format-12 \ + clang-format-18 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ diff --git a/Jenkinsfile b/Jenkinsfile index 50c15701a7..b34e366f1b 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -234,11 +234,6 @@ def cmake_build(Map conf=[:]){ def build_type_debug = (conf.get("build_type",'release') == 'debug') - // use special compiler for gfx950 - if ( check_arch() == 7){ - compiler = "/llvm-project/build/bin/clang++" - } - //cmake_env can overwrite default CXX variables. def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
'''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true @@ -919,8 +914,8 @@ pipeline { description: "Build CK and run tests on gfx90a (default: ON)") booleanParam( name: "BUILD_GFX942", - defaultValue: true, - description: "Build CK and run tests on gfx942 (default: ON)") + defaultValue: false, + description: "Build CK and run tests on gfx942 (default: OFF)") booleanParam( name: "BUILD_GFX950", defaultValue: false, @@ -999,7 +994,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ @@ -1028,7 +1023,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'" + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) @@ -1234,11 +1229,24 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx90a" \ -D GEMM_DATATYPE="fp8;fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& \ - ninja -j64 benchmark_gemm_fp8 && \ - ./bin/benchmark_gemm_fp8 && \ - ninja -j64 benchmark_gemm_fp16 && \ - ./bin/benchmark_gemm_fp16 """ + ninja -j64 benchmark_gemm_fp8_rcr && \ + ./bin/benchmark_gemm_fp8_rcr && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ./bin/benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp8_crr && \ + ./bin/benchmark_gemm_fp8_crr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ./bin/benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp8_ccr && \ + ./bin/benchmark_gemm_fp8_ccr && \ + ninja -j64 benchmark_gemm_fp16_ccr && \ + ./bin/benchmark_gemm_fp16_ccr && \ + ninja -j64 benchmark_gemm_fp8_rrr && \ + ./bin/benchmark_gemm_fp8_rrr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ./bin/benchmark_gemm_fp16_rrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1259,11 +1267,24 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx942" \ -D GEMM_DATATYPE="fp8;fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ - ninja -j128 benchmark_gemm_fp8 && \ - ./bin/benchmark_gemm_fp8 && \ - ninja -j128 benchmark_gemm_fp16 && \ - ./bin/benchmark_gemm_fp16 """ + ninja -j64 benchmark_gemm_fp8_rcr && \ + ./bin/benchmark_gemm_fp8_rcr && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ./bin/benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp8_crr && \ + ./bin/benchmark_gemm_fp8_crr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ./bin/benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp8_ccr && \ + ./bin/benchmark_gemm_fp8_ccr && \ + ninja -j64 benchmark_gemm_fp16_ccr && \ + ./bin/benchmark_gemm_fp16_ccr && \ + ninja -j64 benchmark_gemm_fp8_rrr && \ + ./bin/benchmark_gemm_fp8_rrr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ./bin/benchmark_gemm_fp16_rrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1352,12 +1373,12 @@ pipeline { execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx950" \ - -DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } diff --git a/TERMINOLOGY.md b/TERMINOLOGY.md index e8833efb89..6dbe88640c 100644 --- a/TERMINOLOGY.md +++ b/TERMINOLOGY.md @@ -1,2 +1,348 @@ [Back to the main page](./README.md) -# Composable Kernel terminology \ No newline at end of file + +# Composable Kernel Terminology + +This document provides a technical reference for terminology used in the Composable Kernel library, organized by conceptual progression from hardware to machine learning operations. 
+
+---
+
+## Glossary Index (Alphabetical)
+
+- [Add+Multiply](#addmultiply)
+- [Alignment](#alignment)
+- [Bank Conflict](#bank-conflict)
+- [Batched GEMM](#batched-gemm)
+- [Block Size](#block-size)
+- [Block Tile](#block-tile)
+- [Compute Unit (CU)](#compute-unit-cu)
+- [Coordinate Transformation Primitives](#coordinate-transformation-primitives)
+- [CUDA](#cuda)
+- [Dense Tensor](#dense-tensor)
+- [Descriptor](#descriptor)
+- [Device](#device)
+- [Dilation](#dilation)
+- [Elementwise](#elementwise)
+- [Epilogue](#epilogue)
+- [Fast Changing Dimension](#fast-changing-dimension)
+- [GEMM (General Matrix Multiply)](#general-matrix-multiply-gemm)
+- [GEMV](#gemv)
+- [Global Memory](#global-memory)
+- [Grid](#grid)
+- [Grouped GEMM](#grouped-gemm-ggemms)
+- [HIP](#hip)
+- [Host](#host)
+- [Host-Device Transfer](#host-device-transfer)
+- [Im2Col/Col2Im](#im2colcol2im)
+- [Inner Dimension](#inner-dimension)
+- [Inner Product](#inner-product)
+- [Input/Problem Shape](#inputproblem-shape)
+- [Kernel](#kernel)
+- [Launch Parameters](#launch-parameters)
+- [LDS Banks](#lds-banks)
+- [Load Tile](#load-tile)
+- [Matrix Core](#matrix-core)
+- [Matrix Fused Multiply-Add (MFMA)](#matrix-fused-multiply-add-mfma)
+- [Memory Coalescing](#memory-coalescing)
+- [Norm](#norm)
+- [Occupancy](#occupancy)
+- [Outer Dimension](#outer-dimension)
+- [Outer Product](#outer-product)
+- [Padding](#padding)
+- [Permute/Transpose](#permutetranspose)
+- [Pinned Memory](#pinned-memory)
+- [Pipeline](#pipeline)
+- [Policy](#policy)
+- [Problem](#problem)
+- [Processing Units](#processing-units)
+- [Reference Kernel](#reference-kernel)
+- [Registers](#registers)
+- [Regression Test](#regression-test)
+- [ROCm](#rocm)
+- [Scalar General Purpose Register (SGPR)](#scalar-general-purpose-register-sgpr)
+- [Shared Memory / Local Data Share (LDS)](#shared-memory--local-data-share-lds)
+- [SIMT / SIMD](#single-instruction-multi-thread-simt--single-instruction-multi-data-simd)
+- [Smoke Test](#smoke-test)
+- [Sparse Tensor](#sparse-tensor)
+- [Split-K GEMM](#split-k-gemm)
+- [Store Tile](#store-tile)
+- [Stride](#stride)
+- [Thread / Work-item](#thread--work-item)
+- [Thread Block / Work Group](#thread-block--work-group)
+- [Tile](#tile)
+- [Tile Distribution](#tile-distribution)
+- [Tile Partitioner](#tile-partitioner)
+- [Tile Programming API](#tile-programming-api)
+- [Tile Window](#tile-window)
+- [User Customized Tile Pipeline](#user-customized-tile-pipeline)
+- [User Customized Tile Pipeline Optimization](#user-customized-tile-pipeline-optimization)
+- [Vanilla GEMM](#vanilla-gemm-naive-gemm-kernel)
+- [Vector](#vector)
+- [Vector General Purpose Register (VGPR)](#vector-general-purpose-register-vgpr)
+- [Warp / Wavefront](#warp--wavefront)
+- [Wave Tile](#wave-tile)
+- [XDL Instructions](#xdl-instructions)
+
+---
+
+## 1. Hardware and Memory
+
+### Processing Units
+The GPU is composed of multiple hardware units ([compute units (CUs)](#compute-unit-cu) on AMD, [streaming multiprocessors (SMs)](#compute-unit-cu) on NVIDIA), each containing many cores that run threads in parallel. These units manage shared resources and coordinate execution at scale.
+
+### Matrix Core
+A specialized GPU unit that accelerates matrix operations for AI and deep learning tasks. Modern GPUs contain multiple matrix cores.
+
+### Compute Unit (CU)
+AMD's parallel vector processor in a GPU with multiple ALUs. Each compute unit runs all the waves of a workgroup. _This is equivalent to NVIDIA's streaming multiprocessor (SM)._
+
+### Matrix Fused Multiply-Add (MFMA)
+AMD's matrix core instruction for efficient GEMM operations. CK optimizes kernel designs to maximize MFMA utilization and performance.
+
+### Registers
+The fastest memory tier, registers are private to each thread/work-item and used for storing temporary variables during computation.
AMD distinguishes between [vector (VGPR)](#vector-general-purpose-register-vgpr) and [scalar (SGPR)](#scalar-general-purpose-register-sgpr) registers, while NVIDIA uses a unified register file.
+
+### Vector General Purpose Register (VGPR)
+Per-thread registers that store individual thread data within a wave. Each thread has its own set of VGPRs for private variables and calculations.
+
+### Scalar General Purpose Register (SGPR)
+Wave-level registers shared by all threads in a wave. Used for constants, addresses, and control flow common across the entire wave.
+
+### Shared Memory / Local Data Share (LDS)
+AMD's high-bandwidth, low-latency on-chip memory accessible to all threads within a work group. This is equivalent to NVIDIA's shared memory. It enables fast data sharing and synchronization, but is limited in capacity and must be managed to avoid [bank conflicts](#bank-conflict).
+
+### LDS Banks
+Memory organization in which consecutive addresses are distributed across multiple memory banks so they can be accessed in parallel. Spreading accesses across banks avoids [bank conflicts](#bank-conflict) and improves bandwidth.
+
+### Global Memory
+The main device memory accessible by all threads, offering high capacity but higher latency than shared memory.
+
+### Pinned Memory
+Host memory that is page-locked to accelerate transfers between CPU and GPU, reducing overhead for large data movements.
+
+### Dense Tensor
+A tensor in which most elements are nonzero, typically stored in a contiguous block of memory.
+
+### Sparse Tensor
+A tensor in which most elements are zero, allowing for memory and computation optimizations by storing only nonzero values and their indices.
+
+### Host
+The CPU and main memory system that manages GPU execution: it launches kernels, transfers data, and coordinates overall computation.
+
+### Device
+The GPU hardware that executes parallel kernels. It contains compute units, a memory hierarchy, and specialized accelerators.
+
+---
+
+## 2. GPU Programming Model
+
+### Thread / Work-item
+AMD's work-item is the smallest unit of parallel execution, running an independent instruction stream on a single data element. This is equivalent to NVIDIA's thread. Work-items/threads are grouped into [wavefronts (AMD)](#warp--wavefront) and [warps (NVIDIA)](#warp--wavefront) for efficient scheduling and resource sharing.
+
+### Warp / Wavefront
+AMD's wavefront is a group of threads that run instructions in lockstep, forming the SIMD group. This is equivalent to NVIDIA's warp.
+
+### Thread Block / Work Group
+AMD's work group is a collection of threads/work-items that can synchronize and share memory. This is equivalent to NVIDIA's thread block. Work groups/thread blocks are scheduled independently and mapped to hardware units for execution.
+
+### Grid
+The complete collection of all work groups (thread blocks) that execute a kernel. A grid spans the entire computational domain and is organized in one, two, or three dimensions. Each work group within the grid operates independently and can be scheduled on different compute units, enabling massive parallel execution across the entire GPU.
+
+### Block Size
+The number of work-items/threads in a work group (thread block). Block size determines how the work group maps to hardware and how much register and LDS capacity it consumes.
+
+### Single-Instruction, Multi-Thread (SIMT) / Single-Instruction, Multi-Data (SIMD)
+SIMT (Single-Instruction, Multi-Thread) allows threads in a warp to diverge, while SIMD (Single-Instruction, Multi-Data) enforces strict lockstep execution within wavefronts.
These models define how parallelism is expressed and managed on different architectures.
+
+### Occupancy
+The ratio of active warps/wavefronts to the maximum number of warps/wavefronts supported by a hardware unit. Occupancy affects the ability to hide memory latency and maximize throughput.
+
+---
+
+## 3. Kernel Structure
+
+### Kernel
+A function executed on the GPU, typically written in [HIP](#hip) or [CUDA](#cuda), that performs parallel computations over input data. Kernels are launched with specific grid and block dimensions to map computation to hardware. In CK, a kernel is assembled from three components: a pipeline, a tile partitioner, and an epilogue.
+
+### Pipeline
+A CK Pipeline orchestrates the sequence of operations for a kernel, including data loading, computation, and storage phases. It consists of two core components: a [Problem](#problem) component that defines what to compute, and a [Policy](#policy) component that specifies how to move data around.
+
+### Tile Partitioner
+Defines the mapping between problem dimensions (M, N, K) and the GPU hierarchy. It specifies workgroup-level tile sizes (kM, kN, kK) and determines grid dimensions by dividing the problem size by the tile sizes.
+
+### Problem
+Defines what to compute: input/output shapes, data types, and mathematical operations (e.g., GEMM, convolution).
+
+### Policy
+Defines memory access patterns and hardware-specific optimizations.
+
+### User Customized Tile Pipeline
+A user-defined pipeline that combines custom problem and policy components for specialized computations. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points.
+
+### User Customized Tile Pipeline Optimization
+The process of tuning tile sizes, memory access patterns, and hardware utilization for specific workloads.
+
+### Tile Programming API
+CK's high-level interface for defining tile-based computations with predefined hardware mapping for data load/store.
+
+### Coordinate Transformation Primitives
+CK utilities for converting between different coordinate systems (logical, physical, memory layouts).
+
+### Reference Kernel
+A baseline kernel implementation used to verify correctness and performance. CK has two reference kernel implementations: one for CPU and one for GPU.
+
+### Launch Parameters
+Configuration values (e.g., grid size, block size) that determine how a kernel is mapped to hardware resources. Proper tuning of these parameters is essential for optimal performance.
+
+---
+
+## 4. Memory Access and Data Layout
+
+### Memory Coalescing
+An optimization where consecutive threads access consecutive memory addresses, allowing a single memory transaction to serve multiple threads. Proper coalescing is vital for achieving peak memory bandwidth.
+
+### Alignment
+A memory management strategy for efficient memory access where data structures are stored at addresses that are multiples of a specific value.
+
+### Bank Conflict
+Occurs when multiple threads in a warp/wavefront access different addresses mapping to the same shared memory bank, causing serialization and reduced bandwidth.
+
+### Padding
+The addition of extra elements (often zeros) to tensor edges. This is used to control output size in convolution and pooling, or to align data for efficient memory access.
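+
+The [Bank Conflict](#bank-conflict) and [Padding](#padding) entries come together in a classic LDS idiom: pad each row of a staged tile by one element so that column-wise reads land in different banks. The following HIP-style kernel is a minimal, hypothetical sketch of that idiom (illustrative only, not CK API):
+
+```cpp
+#include <hip/hip_runtime.h>
+
+// One 32x32 work group stages a tile in LDS, then reads it back
+// column-wise to transpose it. Without the "+ 1" pad, every element of a
+// column maps to the same LDS bank and the reads serialize; the extra
+// column staggers each row across banks.
+__global__ void transpose_32x32(const float* __restrict__ in,
+                                float* __restrict__ out)
+{
+    __shared__ float tile[32][32 + 1]; // padded row: 33 floats
+
+    const int x = threadIdx.x;
+    const int y = threadIdx.y;
+
+    tile[y][x] = in[y * 32 + x]; // coalesced, conflict-free row-wise write
+    __syncthreads();
+    out[y * 32 + x] = tile[x][y]; // column-wise read, conflict-free with pad
+}
+```
+
+Launched as a single 32x32 work group, this writes the transpose of `in` to `out`; dropping the pad turns each column read into a multi-way bank conflict on a typical 32-bank LDS.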
+
+### Permute/Transpose
+Operations that rearrange the order of tensor axes, often required to match kernel input formats or optimize memory access patterns.
+
+### Host-Device Transfer
+The process of moving data between CPU (host) and GPU (device) memory. Host-device transfers can be a performance bottleneck and are optimized using pinned memory and asynchronous operations.
+
+### Stride
+The step size to move from one element to the next in a particular dimension of a tensor or matrix. In convolution and pooling, stride determines how far the kernel moves at each step.
+
+### Dilation
+The spacing between kernel elements in convolution operations, allowing the receptive field to grow without increasing kernel size.
+
+### Im2Col/Col2Im
+Data transformation techniques that convert image data to column format (im2col) for efficient convolution and back (col2im) to reconstruct the original layout.
+
+### Fast Changing Dimension
+The innermost dimension, whose index changes fastest as memory addresses increase.
+
+### Outer Dimension
+A slower-changing dimension in the memory layout.
+
+### Inner Dimension
+A faster-changing dimension in the memory layout.
+
+---
+
+## 5. Tile-Based Computing and Data Structures
+
+### Tile
+A sub-region of a tensor or matrix processed by a block or thread. Tiles are used to improve memory locality and enable blocking strategies in kernels. These rectangular blocks are the unit of computation and memory transfer in CK and the basis of its tiled algorithms.
+
+### Block Tile
+Memory tile processed by a work group (thread block).
+
+### Wave Tile
+Sub-tile processed by a single wave within a work group. Represents the granularity of SIMD execution.
+
+### Tile Distribution
+The hierarchical mapping from work-items to the data they own in memory.
+
+### Tile Window
+Viewport into a larger tensor that defines the current tile's position and boundaries for computation.
+
+### Load Tile
+Operation that transfers data from global memory/LDS to per-thread registers using optimized memory access patterns.
+
+### Store Tile
+Operation that transfers data from per-thread registers to LDS/global memory using optimized memory access patterns.
+
+### Descriptor
+Metadata structure that defines tile properties, memory layouts, and coordinate transformations for CK operations.
+
+### Input/Problem Shape
+Dimensions and data types of input tensors that define the computational problem (e.g., M×K, K×N for GEMM).
+
+### Vector
+Smallest data unit processed by individual threads. Typically 4-16 elements depending on data type and hardware.
+
+---
+
+## 6. Kernel Operations and Optimization
+
+### Elementwise
+Operations applied independently to each tensor element, such as addition or multiplication. These are highly parallelizable and benefit from efficient memory access.
+
+### Epilogue
+The final stage of a kernel or operation, often applying activation functions, bias, or other post-processing steps. Epilogues are critical for integrating kernel outputs into larger computation graphs.
+
+### Add+Multiply
+A common fused operation in ML and linear algebra, where an elementwise addition is immediately followed by multiplication, often used for bias and scaling in neural network layers.
+
+---
+
+## 7. Linear Algebra and ML Operations
+
+### General Matrix Multiply (GEMM)
+Core matrix operation in linear algebra and deep learning. A GEMM is defined as C = αAB + βC for matrices A, B, and C.
+
+### "Vanilla" GEMM (Naive GEMM) Kernel
+The **vanilla GEMM** is the simplest form of GEMM in CK. It:
+- Takes input matrices **A** and **B**
+- Multiplies them to produce output matrix **C**
+
+This is the **baseline** or **building block** GEMM that all more complex versions expand upon.
+
+### Grouped GEMM (GGEMMs)
+A single kernel that performs multiple vanilla GEMMs ("VGEMMs"), where each call can have a different input shape. Each problem first selects a matching kernel configuration, and its data is then mapped onto that kernel's work groups (blocks).
+
+### Batched GEMM
+A kernel that performs the same vanilla GEMM over multiple "batches" of data; all batches share the same input shape.
+
+### Split-K GEMM
+A parallelization strategy that partitions the reduction dimension (K) across multiple compute units, increasing parallelism for large matrix multiplications.
+
+### GEMV
+The operation of multiplying a matrix by a vector, producing another vector. GEMV (General Matrix-Vector Multiplication) is a core linear algebra primitive, widely used in neural networks and scientific computing.
+
+### Inner Product
+Also known as the dot product, it computes the sum of elementwise products of two vectors, yielding a scalar.
+
+### Outer Product
+The result of multiplying a column vector by a row vector, producing a matrix. Outer products are used in rank-1 updates and some ML algorithms.
+
+### Norm
+A function that measures the magnitude of a vector or matrix, such as L2 (Euclidean) or L1 norm. Norms are used in regularization, normalization, and optimization.
+
+---
+
+## 8. Testing, Build, and Infrastructure
+
+### Regression Test
+Tests in CK's ctest suite that take more than 30 seconds to finish on gfx942.
+
+### Smoke Test
+Tests in CK's ctest suite that take 30 seconds or less to finish on gfx942.
+
+---
+
+## 9. Low-Level Instructions and Optimizations
+
+### XDL Instructions
+AMD's matrix-core instructions (also known as XDLOPS), exposed as MFMA operations on CDNA GPUs. CK device operations with "Xdl" in their names build on these instructions to accelerate GEMM-like computation.
+
+---
+
+## 10. Miscellaneous
+
+### HIP
+AMD's Heterogeneous-Computing Interface for Portability, a C++ runtime API and programming language that enables developers to create portable applications for AMD and NVIDIA GPUs. HIP provides a familiar CUDA-like programming model while maintaining compatibility across different GPU architectures.
+
+### CUDA
+NVIDIA's Compute Unified Device Architecture, a parallel computing platform and programming model for NVIDIA GPUs. CUDA provides a C++ extension for writing GPU kernels and managing GPU resources.
+
+### ROCm
+AMD's Radeon Open Compute platform, an open-source software stack for GPU computing that includes [HIP](#hip), libraries, and tools for high-performance computing and machine learning workloads on AMD GPUs.
+
+---
+
+## Scientific Context and References
+
+This terminology is grounded in parallel computing theory, numerical linear algebra, and computer architecture. For further reading, see:
+- [Building Efficient GEMM Kernels with CK Tile](https://rocm.blogs.amd.com/software-tools-optimization/building-efficient-gemm-kernels-with-ck-tile-vendo/README.html)
+- [CK Tile Flash](https://rocm.blogs.amd.com/software-tools-optimization/ck-tile-flash/README.html)
+
+This document assumes familiarity with parallel computing, linear algebra, and computer architecture principles.
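+
+To make the GEMM definition above concrete, here is a minimal host-side sketch of C = αAB + βC in the spirit of CK's CPU [reference kernels](#reference-kernel). The function name and plain row-major `std::vector` storage are illustrative assumptions, not CK API:
+
+```cpp
+#include <vector>
+
+// Naive row-major reference GEMM: C = alpha * A * B + beta * C.
+// Slow but obviously correct; this is the role a reference kernel plays
+// when validating an optimized device implementation.
+void reference_gemm(int M, int N, int K, float alpha,
+                    const std::vector<float>& A,       // M x K
+                    const std::vector<float>& B,       // K x N
+                    float beta, std::vector<float>& C) // M x N
+{
+    for(int m = 0; m < M; ++m)
+    {
+        for(int n = 0; n < N; ++n)
+        {
+            float acc = 0.0f;
+            for(int k = 0; k < K; ++k)
+                acc += A[m * K + k] * B[k * N + n];
+            C[m * N + n] = alpha * acc + beta * C[m * N + n];
+        }
+    }
+}
+```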
diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp index 480abf23d2..13f1a3acc1 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp @@ -107,14 +107,14 @@ int execute_conv_fwd() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp index ae5f1b6f6e..f31ffe302a 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp @@ -130,14 +130,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp index 2309d757f0..a9918f6ab3 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp @@ -105,14 +105,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp index 93709a7901..baa2b02bce 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp @@ -109,14 +109,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp index a62a1d911b..ac7eb3cf41 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp @@ -111,14 +111,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index 69d7c8936c..37cafc190e 100644 --- 
a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -59,7 +59,7 @@ int main() SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size); std::array ab_input = {a_dev_buf.GetDeviceBuffer(), - b_dev_buf.GetDeviceBuffer()}; + b_dev_buf.GetDeviceBuffer()}; std::vector abStride = {Stride, 1}; std::array, 2> abStrides = {abStride, abStride}; diff --git a/client_example/15_reduce/reduce_nhwc_c.cpp b/client_example/15_reduce/reduce_nhwc_c.cpp index e2b1fbcb54..12aa31dec3 100644 --- a/client_example/15_reduce/reduce_nhwc_c.cpp +++ b/client_example/15_reduce/reduce_nhwc_c.cpp @@ -68,15 +68,15 @@ int main(int argc, char* argv[]) SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements); using DeviceOp = ck::tensor_operation::device::DeviceReduce; + AccDataType, + OutDataType, + Rank, + NumReduceDim, + ReduceAdd, + PassThrough, + UnaryDivide, + PropagateNan, + OutputIndex>; const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp index bb106e8d8e..e8e33a3de2 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp @@ -117,14 +117,14 @@ int execute_conv_bwd_data_bilinear() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {in.GetDeviceBuffer()}, + {in.GetDeviceBuffer()}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {in_lengths}, - {in_strides}, + {in_lengths}, + {in_strides}, in_lengths, in_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp index e53ecc6c99..d81b5fd03e 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp @@ -116,14 +116,14 @@ int execute_conv_bwd_data_scale() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp index 32ab481319..2ec70b8b9b 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp @@ -121,14 +121,14 @@ int execute_conv_fwd_bilinear() auto& op_ptr = op_ptrs[i]; auto argument_ptr = 
op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {out.GetDeviceBuffer()}, + {out.GetDeviceBuffer()}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {out_lengths}, - {out_strides}, + {out_lengths}, + {out_strides}, out_lengths, out_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp index c78cacf266..98f41dc7fb 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp @@ -222,13 +222,13 @@ bool run_grouped_conv_fwd_convscale_reduce( ck::tensor_operation::element_wise::Scale{scale_wei}, {}}; auto conv_ok = ConvolutionScale(in, + WeiDataType, + ConvOutDataType, + ConvElementOp, + InLayout, + WeiLayout, + OutLayout, + NumDimSpatial>(in, wei, conv_out, elementwise_op, @@ -717,15 +717,15 @@ bool TensorFullReduction(SimpleDeviceMem& tensor, { std::cout << "\nReduction of spatial dimensions:" << std::endl; using DeviceOp = ck::tensor_operation::device::DeviceReduce; // OutputIndex + OutDataType, + OutDataType, + NumDimSpatial, + NumDimSpatial, + ReduceOperation, + PassThrough, + AccElementwiseOperation, + true, // PropagateNan + false>; // OutputIndex const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp index 11e69f5bb2..11f24b39c7 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp @@ -120,14 +120,14 @@ int execute_conv_fwd_scale() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc index 3f6f7b0773..4cf3a4cf82 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc @@ -129,8 +129,8 @@ int execute_conv_fwd_scaleadd_ab() in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp index ceccc5eb8f..f7f893fda2 100644 --- a/client_example/25_wrapper/wrapper_img2col.cpp +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -132,9 +132,9 @@ void PerformImageToColumnPad0(const ck::index_t G, ck::wrapper::size<0>(tile_shape)); const auto kernel = DeviceImageToColumnPad0; + decltype(output_tensor_global), + decltype(tile_shape), + decltype(thread_layout)>; const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, kernel, dim3(grid_size_x, 
grid_size_y, 1), diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 8fdd60f5d5..f27e557cc3 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.15) project(ck_app) -add_compile_options(-std=c++17) +add_compile_options(-std=c++20) if (DTYPES) add_definitions(-DDTYPES) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 0915f53411..6587f4c4be 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -68,3 +68,6 @@ endif() target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS}) target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS}) +target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0) +target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0) + diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 35b5cf0367..2b2e6e2949 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -22,7 +22,7 @@ file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) -add_compile_options(-std=c++17) +add_compile_options(-std=c++20) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp index 89c1884d2e..81b312ec95 100644 --- a/codegen/include/ck/host/stringutils.hpp +++ b/codegen/include/ck/host/stringutils.hpp @@ -91,8 +91,9 @@ inline auto Transform(const Range& r, F f) -> std::vector -inline auto Transform(const Range1& r1, const Range2& r2, F f) - -> std::vector +inline auto Transform(const Range1& r1, + const Range2& r2, + F f) -> std::vector { std::vector result; assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end())); diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp index 36c9a13b4c..a2f322c50f 100644 --- a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp @@ -142,12 +142,11 @@ std::vector Operation_Conv_Fwd_Xdl_Cshuffle::Cr x.A = TensorDesc{prob.ADataType, prob.ALayout}; x.B = TensorDesc{prob.BDataType, prob.BLayout}; x.E = TensorDesc{prob.EDataType, prob.ELayout}; - x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { - return TensorDesc{dt, lo}; - }); - x.a_elem_op = prob.AElementOp; - x.b_elem_op = prob.BElementOp; - x.cde_elem_op = prob.CDEElementOp; + x.Ds = Transform( + prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { return TensorDesc{dt, lo}; }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; x.update_prologue(prologue); x.update_epilogue(epilogue); result.push_back(x); diff --git a/codegen/test/batched_gemm_softmax_gemm.cpp b/codegen/test/batched_gemm_softmax_gemm.cpp index 13035df355..98e78fc148 100644 --- a/codegen/test/batched_gemm_softmax_gemm.cpp +++ b/codegen/test/batched_gemm_softmax_gemm.cpp @@ -55,12 +55,12 @@ TEST_CASE(test_problem_kernel) std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}, - {"o", std::to_string(prob.O)}}); + {{"include", 
prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}, + {"o", std::to_string(prob.O)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index adc8e1ff02..dd908e8b58 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -60,11 +60,11 @@ TEST_CASE(test_problem_kernel) std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index 2f3b26cc43..f4983debd9 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -16,7 +16,7 @@ struct tmp_dir void execute(const std::string& cmd) const; - tmp_dir(tmp_dir const&) = delete; + tmp_dir(tmp_dir const&) = delete; tmp_dir& operator=(tmp_dir const&) = delete; ~tmp_dir(); diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 262e6bae46..fac92ded7d 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -94,7 +94,7 @@ kernel clang_compile_kernel(const std::vector& srcs, compile_options o assert(not srcs.empty()); tmp_dir td{"compile"}; options.flags += " -I. -O3"; - options.flags += " -std=c++17"; + options.flags += " -std=c++20"; options.flags += " --offload-arch=" + get_device_name(); std::string out; @@ -278,7 +278,7 @@ std::vector> compile_hip_src_with_hiprtc(const std::vector& srcs, compile_options options) { options.flags += " -I. 
-O3"; - options.flags += " -std=c++17"; + options.flags += " -std=c++20"; options.flags += " -DCK_CODE_GEN_RTC"; options.flags += " --offload-arch=" + get_device_name(); auto cos = compile_hip_src_with_hiprtc(srcs, options); diff --git a/docs/install/Composable-Kernel-prerequisites.rst b/docs/install/Composable-Kernel-prerequisites.rst index 10be849ea6..9dc082599a 100644 --- a/docs/install/Composable-Kernel-prerequisites.rst +++ b/docs/install/Composable-Kernel-prerequisites.rst @@ -29,4 +29,4 @@ The following prerequisites are required to build and install Composable Kernel: * zlib1g-dev * libzstd-dev * openssh-server -* clang-format-12 +* clang-format-18 diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index e6a26ecafd..61f3ba5351 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -128,3 +128,5 @@ add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.c add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3) add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3) +add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale) diff --git a/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp new file mode 100644 index 0000000000..d3ac184019 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 64; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_BScale_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 128, 128, + KPerBlock, 8, 8, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + CDataType, CDataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = 
problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + 
for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + b1_scale_device_buf.ToDevice(b1_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + Scale_Stride_BN, + static_cast(b1_scale_device_buf.GetDeviceBuffer()), + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + std::string device_name = ck::get_device_name(); + if(!(device_name.find("gfx11") != std::string::npos || + device_name.find("gfx12") != std::string::npos)) + { + std::cout << "This kernel support gfx1100 and gfx1200 only" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_dequant({K, N}); + + float v_b = 0; + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_k_n_dequant(k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_k_n(k / Scale_Block_K, n / Scale_Block_N)); + } + } + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * 
M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 5afb3d1554..b55627f3ee 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -31,15 +31,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl #else < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; #endif - // clang-format on +// clang-format on - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; template std::ostream& show_2d_matrix(std::ostream& os, Tensor& matrix) diff --git a/example/12_reduce/reduce_blockwise_impl.hpp b/example/12_reduce/reduce_blockwise_impl.hpp index f1225d86e4..57a86a9dc4 100644 --- a/example/12_reduce/reduce_blockwise_impl.hpp +++ b/example/12_reduce/reduce_blockwise_impl.hpp @@ -117,7 +117,7 @@ int reduce_blockwise_impl(bool do_verification, using InOutDataTypeInDevice = typename std:: conditional::value, int8_t, InOutDataType>::type; #else - using InOutDataTypeInDevice = InOutDataType; + using InOutDataTypeInDevice = InOutDataType; #endif using DeviceReduceInstance = diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp index 1bea1bcf3e..3e3c586dba 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp @@ -175,15 +175,15 @@ auto run_gemm_reduce_max_xdl(ck::index_t M, auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), - {}, + {}, e_device_buf.GetDeviceBuffer(), - {r0_device_buf.GetDeviceBuffer()}, + {r0_device_buf.GetDeviceBuffer()}, M, N, K, StrideA, StrideB, - {}, + {}, StrideE, a_element_op, b_element_op, diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index 62295c57eb..42bfea372e 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -207,7 +207,7 @@ int main(int argc, char* argv[]) auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), nullptr, - {}, + {}, c_device_buf.GetDeviceBuffer(), p_reduces, M, @@ -216,9 +216,9 @@ int main(int argc, char* argv[]) StrideA, 
StrideB, StrideC, - {}, + {}, gemm_element_ops, - {}, + {}, reduce_in_element_ops, reduce_out_element_ops, BatchCount); diff --git a/example/27_layernorm2d_fwd/run_layernorm_example.inc b/example/27_layernorm2d_fwd/run_layernorm_example.inc index 23608a1eea..02b60fe548 100644 --- a/example/27_layernorm2d_fwd/run_layernorm_example.inc +++ b/example/27_layernorm2d_fwd/run_layernorm_example.inc @@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example() {0, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index cdfd86dff4..c693995140 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -126,10 +126,10 @@ int run(int argc, char* argv[]) if(i < 4) { - std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " - << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", " - << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", " - << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; + std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " << "b0_gs_ns_ks[" + << i << "]: " << b0_gs_ns_ks.mDesc << ", " << "b1_gs_os_ns[" << i + << "]: " << b1_gs_os_ns.mDesc << ", " << "c_gs_ms_os[" << i + << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; } switch(init_method) diff --git a/example/34_batchnorm/batchnorm_backward_nhwc.cpp b/example/34_batchnorm/batchnorm_backward_nhwc.cpp index 3756310fd7..9737b0d99b 100644 --- a/example/34_batchnorm/batchnorm_backward_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_backward_nhwc.cpp @@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification, return (pass); }; -static const double epsilon = std::numeric_limits::epsilon(); - int main(int argc, char* argv[]) { + static const double epsilon = std::numeric_limits::epsilon(); + bool pass = true; if(argc > 1) diff --git a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp index 6a8002025a..1ffbabd04b 100644 --- a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp @@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification, return (pass); }; -static const double epsilon = std::numeric_limits::epsilon(); - int main(int argc, char* argv[]) { - bool pass = true; + static const double epsilon = std::numeric_limits::epsilon(); + bool pass = true; if(argc > 1) { diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp index b27358fd9d..06441be860 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp @@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification, return (pass); }; -const double epsilon = std::numeric_limits::epsilon(); -static const double averageFactor = 0.1; - int main(int argc, char* argv[]) { - bool 
pass = true; + const double epsilon = std::numeric_limits::epsilon(); + static const double averageFactor = 0.1; + bool pass = true; if(argc > 1) { diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp index ffb9f4b584..8f2b7613b5 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp @@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification, return (pass); }; -const double epsilon = std::numeric_limits::epsilon(); -static const double averageFactor = 0.1; - int main(int argc, char* argv[]) { - bool pass = true; + const double epsilon = std::numeric_limits::epsilon(); + static const double averageFactor = 0.1; + bool pass = true; if(argc > 1) { diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index d2337dcda5..26a03f289d 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -129,11 +129,11 @@ int main() auto argument_ptr = device_instance.MakeArgumentPointer( out_dev.GetDeviceBuffer(), {ck::type_convert(emb_a_dev.GetDeviceBuffer()), - ck::type_convert(emb_b_dev.GetDeviceBuffer()), - ck::type_convert(emb_c_dev.GetDeviceBuffer())}, + ck::type_convert(emb_b_dev.GetDeviceBuffer()), + ck::type_convert(emb_c_dev.GetDeviceBuffer())}, {ck::type_convert(index_a_dev.GetDeviceBuffer()), - ck::type_convert(index_b_dev.GetDeviceBuffer()), - ck::type_convert(index_c_dev.GetDeviceBuffer())}, + ck::type_convert(index_b_dev.GetDeviceBuffer()), + ck::type_convert(index_c_dev.GetDeviceBuffer())}, gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), current_dim, diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index 6af8ac6488..1823d4fc0a 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -92,7 +92,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_params = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/example/39_permute/common.hpp b/example/39_permute/common.hpp index 54f3a78809..b23128a536 100644 --- a/example/39_permute/common.hpp +++ b/example/39_permute/common.hpp @@ -249,8 +249,8 @@ inline auto to_array(Range& range) noexcept } template -inline auto is_valid_axes(const Axes& axes) - -> std::enable_if_t, bool> +inline auto +is_valid_axes(const Axes& axes) -> std::enable_if_t, bool> { using std::empty; if(empty(axes)) @@ -357,10 +357,11 @@ auto extend_axes(const Problem::Axes& axes) } template -auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t< - detail::is_bidirectional_range_v && detail::is_sized_range_v && - detail::is_bidirectional_range_v && detail::is_sized_range_v, - bool> +auto advance_indices(const Shape& shape, Indices& indices) + -> std::enable_if_t< + detail::is_bidirectional_range_v && detail::is_sized_range_v && + detail::is_bidirectional_range_v && detail::is_sized_range_v, + bool> { using std::size; if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices))) diff --git 
a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc index 853ff791a6..ab6f317bc6 100644 --- a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc +++ b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc @@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) {0, 0, 0, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1, 2, 4}, // reduction dimension: [H, W, C] 1e-6, x_dev.GetDeviceBuffer(), diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index 9431a8cde4..c40447e1f9 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -152,7 +152,7 @@ int main(int argc, char* argv[]) std::array inputs = {input_dev_buf.GetDeviceBuffer()}; std::array outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(), - output_scaled_casted_dev_buf.GetDeviceBuffer()}; + output_scaled_casted_dev_buf.GetDeviceBuffer()}; std::cout << "Input: " << input.mDesc << std::endl; std::cout << "Scale: " << scale << std::endl; @@ -164,8 +164,8 @@ int main(int argc, char* argv[]) auto launch_transpose_scale = [&]() { auto transposeScale = DeviceElementwisePermuteInstance{}; auto argument = transposeScale.MakeArgumentPointer(dims, - {in_strides}, - {out_strides, in_strides}, + {in_strides}, + {out_strides, in_strides}, inputs, outputs, ScalePassThrough{scale}); diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index 8b88e2482d..e7c1d6f0be 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -213,7 +213,7 @@ int main(int argc, char* argv[]) auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument( std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, + a1_device_buf.GetDeviceBuffer()}, std::array{b_device_buf.GetDeviceBuffer()}, std::array{d_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp index eaabccdf2a..ec1b2d6018 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -194,9 +194,9 @@ int main(int argc, char* argv[]) auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument( std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, + a1_device_buf.GetDeviceBuffer()}, std::array{b0_device_buf.GetDeviceBuffer(), - b1_device_buf.GetDeviceBuffer()}, + b1_device_buf.GetDeviceBuffer()}, std::array{}, e_device_buf.GetDeviceBuffer(), std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp index 
6940c20695..f521c51d67 100644 --- a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp @@ -265,10 +265,10 @@ bool run_grouped_conv_fwd(bool do_verification, auto device_ew_scale = DeviceElementwiseScale{}; auto scale_invoker = device_ew_scale.MakeInvoker(); auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths, - {e_g_n_k_wos_strides}, - {e_g_n_k_wos_strides}, - {conv_device_buf.GetDeviceBuffer()}, - {out_device_buf.GetDeviceBuffer()}, + {e_g_n_k_wos_strides}, + {e_g_n_k_wos_strides}, + {conv_device_buf.GetDeviceBuffer()}, + {out_device_buf.GetDeviceBuffer()}, scale_convert); if(!device_ew_scale.IsSupportedArgument(scale_argument)) diff --git a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc index 1a0b558e2c..f75c01ec61 100644 --- a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc +++ b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc @@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example() {0, W * C, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1, 2, 3}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 9e80a2ca35..f78e6e48a5 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -357,7 +357,7 @@ int main(int argc, char* argv[]) int n1 = n % NLane; int k0 = k / (KLane * KPack); - tempk = k % (KLane * KPack); + tempk = k % (KLane * KPack); int k1 = tempk / KPack; int k2 = tempk % KPack; diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 56d709f41b..7bd628edf2 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -24,26 +24,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(result 1) if(DEFINED DTYPES) foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" 
IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) set(test 1) endif() if(test EQUAL 1) @@ -55,81 +56,74 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) - #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + get_filename_component(source_name ${source} NAME) + #Do not build any DL examples if DL_KERNELS not set + if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl") message(DEBUG "removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any DPP examples if DPP_KERNELS not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp") + #Do not build any DPP examples if DPP_KERNELS not set + if(NOT DEFINED DPP_KERNELS AND source_name MATCHES "_dpp") message(DEBUG "removing dpp example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any XDL examples if gfx9 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + #Do not build any XDL examples if gfx9 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any WMMA examples if gfx11 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + #Do not build any WMMA examples if gfx11 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") message(DEBUG "removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any microscaling examples if gfx950 target is not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + #Do not build any microscaling examples if gfx950 target is not on the list + if(NOT EX_TARGETS MATCHES "gfx950" AND source_name MATCHES "_mx") message(DEBUG "removing microscaling example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any FP8 examples if CK_ENABLE_FP8 not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8") + #Do not build any FP8 examples if CK_ENABLE_FP8 not set + if(NOT DEFINED CK_ENABLE_FP8 AND source_name MATCHES "_fp8") message(DEBUG "removing fp8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any BF8 examples if CK_ENABLE_BF8 not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8") + #Do not build any BF8 examples if CK_ENABLE_BF8 not set + if(NOT DEFINED CK_ENABLE_BF8 AND source_name MATCHES "_bf8") message(DEBUG "removing bf8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - # Build fp8 gemm_multiply_multiply and moe only on gfx94/95 - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95") - if (source MATCHES "fp8" AND source MATCHES "(gemm_multiply_multiply|moe)") - 
message(DEBUG "Skipping ${source} example for current target") - list(REMOVE_ITEM FILE_NAME "${source}") + # Build fp8 gemm_multiply_multiply and moe only on gfx94/95 + if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95") + if(source_name MATCHES "fp8" AND source_name MATCHES "(gemm_multiply_multiply|moe)") + message(DEBUG "Skipping ${source} example for current target") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() endif() - endif() endforeach() #only continue if there are some source files left on the list + set(source_name_list "") + foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) + list(APPEND source_name_list ${source_name}) + endforeach() if(FILE_NAME) - if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4") + if(source_name_list MATCHES "_xdl" AND NOT source_name_list MATCHES "_pk_i4") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_wmma") + elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) - elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 + elseif(source_name_list MATCHES "_mx") #only build mx example for gfx950 list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 + elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 message(DEBUG "trimming targets for ${FILE_NAME}") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt) add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) - set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS}) add_dependencies(examples ${EXAMPLE_NAME}) add_dependencies(check ${EXAMPLE_NAME}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) @@ -156,71 +150,71 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) message(DEBUG "adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS FILE_NAME) - set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf16" OR 
source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) - set(test 1) - endif() - if(test EQUAL 1) - message(DEBUG "removing example ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() - endforeach() + foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) + set(test 0) + if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message(DEBUG "removing example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() endif() set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) - #Do not build any DL examples if DL_KERNELS not set + set(source_name_list "") foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + get_filename_component(source_name ${source} NAME) + #Do not build any DL examples if DL_KERNELS not set + if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl") message(DEBUG "removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any XDL examples if gfx9 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + #Do not build any XDL examples if gfx9 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any WMMA examples if gfx11 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + #Do not build any WMMA examples if gfx11 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") message(DEBUG "removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() + list(APPEND source_name_list ${source_name}) endforeach() #only continue if there are some source files left on the list if(FILE_NAME) - if(FILE_NAME MATCHES "_xdl") + if(source_name_list MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_wmma") + elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} 
PRIVATE utility) add_dependencies(examples ${EXAMPLE_NAME}) - set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 89fbcff40c..30b524d606 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -7,7 +7,7 @@ from dataclasses import dataclass import fnmatch import itertools from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Literal from codegen.cmake_config import * from codegen.cpp_symbol_map import * @@ -204,107 +204,13 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) }} """ -@dataclass -class FmhaBwdDQDKDVApiTrait: - pipeline : str - # sync with fmha_bwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along k seqlen - bhdq : int # q head_dim - bhdv : int # v head_dim - mask : str - bias : str - dbias : str - dropout : str - spad : str - skpad : str - dpad : str - dvpad : str - deterministic : str - - def scheck(self, spad1 : str) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.spad == 't' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} != 0' - elif self.spad == 'f' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0' - else: # self.skpad == 'f' and skpad1 == 'f' - return f'a.seqlen_q % 64 == 0' - - @property - def skcheck(self) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.skpad == 't': - return f'a.seqlen_k % {self.bn0} != 0' - else: - return f'a.seqlen_k % {self.bn0} == 0' - - @property - def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' - - @property - def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' - -class FmhaBwdApiPool: - def __init__(self, mask_impl): - self.dq_dk_dv_pool = dict() - self.mask_impl = mask_impl - - def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: - # TODO: do we need to check duplication? 
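The FmhaBwdDQDKDVApiTrait removed here is folded into the new FmhaBwdApiTrait further down rather than deleted; its scheck/skcheck/dcheck/dvcheck properties turn a kernel's tile shape and padding traits into C++ predicate strings, and the generated api chains those into an if/else-if ladder so each runtime problem size dispatches to a kernel whose padding matches. A minimal Python sketch of the idea, with illustrative values rather than the exact codegen:

    # Sketch: map padding traits plus tile size to the C++ condition string
    # that the generated dispatch uses to select this kernel at runtime.
    def scheck(mode: str, spad: str, spad1: str, bm0: int) -> str:
        if mode == 'group':
            return 'true'  # group mode always supports ragged seqlen_q
        if spad == 't' and spad1 == 't':
            return f'a.seqlen_q % {bm0} != 0'  # padded kernel takes the remainder
        if spad == 'f' and spad1 == 't':
            return f'a.seqlen_q % {bm0} == 0 and a.seqlen_q % 64 != 0'
        return 'a.seqlen_q % 64 == 0'  # fully aligned fast path

    print(scheck('batch', 't', 't', 128))  # emits: a.seqlen_q % 128 != 0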
- if trait.dtype not in self.dq_dk_dv_pool.keys(): - self.dq_dk_dv_pool[trait.dtype] = dict() - if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): - self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() - - self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): - traits=self.dq_dk_dv_pool[dtype][hdim] - hdim_int = int(hdim) - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - for spad1 in ["t", "f"]: - if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")): - continue - inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], - F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], - F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_deterministic=BOOL_MAP[trait.deterministic]) - - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: - # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) - # GEMM0: Q@K=S^T # GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) # GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) # GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) # GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) # Is it necessary to distinguish between K0~K4? 
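The five GEMMs listed above chain together as the standard attention backward pass. A minimal NumPy sketch with illustrative shapes (ignoring softmax scaling, masking, dropout, and the transposed layouts the kernels actually use) makes the N1/N3/N4 constraints concrete:

    import numpy as np

    sq, sk, dqk, dv = 128, 256, 64, 64   # seqlen_q, seqlen_k, headdim_qk, headdim_v
    q = np.random.rand(sq, dqk)
    k = np.random.rand(sk, dqk)
    v = np.random.rand(sk, dv)
    do = np.random.rand(sq, dv)          # incoming gradient of O = P @ V

    s = q @ k.T                          # GEMM0: scores, (sq, sk)
    p = np.exp(s - s.max(-1, keepdims=True))
    p /= p.sum(-1, keepdims=True)        # row-wise softmax
    dv_grad = p.T @ do                   # GEMM1: (sk, dv), so N1 == headdim_v
    dp = do @ v.T                        # GEMM2: (sq, sk)
    ds = p * (dp - (dp * p).sum(-1, keepdims=True))  # softmax backward
    dk_grad = ds.T @ q                   # GEMM3: (sk, dqk), so N3 == headdim_qk
    dq_grad = ds @ k                     # GEMM4: (sq, dqk), so N4 == headdim_qk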
-@dataclass +@dataclass(frozen=True) class FmhaBwdDQDKDVTileSize: F_bm0 : int # tile size along q seqlen (block size) F_bn0 : int # tile size along k seqlen @@ -337,7 +243,7 @@ class FmhaBwdDQDKDVTileSize: f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}" -@dataclass +@dataclass(frozen=True) class FmhaBwdDQDKDVKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim @@ -440,26 +346,6 @@ class FmhaBwdDQDKDVKernel: def filename(self) -> str: return self.name + ".cpp" - def api_trait(self) -> FmhaBwdDQDKDVApiTrait: - return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bhdq=self.F_tile.F_bhdq, - bhdv=self.F_tile.F_bhdv, - mask=self.F_mask, - bias=self.F_bias, - dbias=self.F_dbias, - dropout=self.F_dropout, - spad=self.F_spad, - skpad=self.F_skpad, - dpad=self.F_dpad, - dvpad=self.F_dvpad, - deterministic=self.F_deterministic - ) - # TODO: design a more practical way to do it # this is current supported tile size & pipeline. def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: @@ -477,87 +363,6 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict else: return None -def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: - # TODO: we don't support tuning yet, so pick up one value for pad - # support this in future - gen = list() - api_pool = FmhaBwdApiPool(mask_impl) - - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: - continue - for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): - tile = d[hdim_str][0] - ppl = d[hdim_str][1] - hdim = int(hdim_str) - if (mode == "group") and (spad == "f" or skpad == "f"): - continue - if ((bias == "no" or bias == "alibi") and dbias == "t"): - continue - if ("wg32" in dropout): - continue - if (dpad == "t" or dvpad == "t"): - ppl = d[hdim_str][2] - k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, - F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, - F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, - F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - # Flash attention integration - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - elif receipt == 3: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dpad == dvpad - cond &= deterministic == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'bias'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - cond &= mode == 'batch' - cond &= deterministic == "f" - if not cond: - continue - # Aiter (mha_bwd) 
integration - elif receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - # Aiter (mha_varlen_bwd) integration - elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - # aiter::mha_bwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= dpad == dvpad - if not cond: - continue - api_pool.register_dq_dk_dv_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; @@ -616,7 +421,7 @@ std::string fmha_bwd_dot_do_o_get_name_() }} """ -@dataclass +@dataclass(frozen=True) class FmhaBwdOGradDotOKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim @@ -656,49 +461,6 @@ class FmhaBwdOGradDotOKernel: def filename(self) -> str: return self.name + ".cpp" -def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 - - gen = list() - - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: - continue - for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): - hdim = int(hdim_str) - if (mode == "group" and spad == "f"): - continue - k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, - F_spad=spad, F_dvpad=dvpad, F_mode=mode, - F_occupancy=get_occupancy(dtype, hdim)) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - # Aiter (mha_bwd) integration - if receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - if not cond: - continue - # Aiter (mha_varlen_bwd) integration - elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - if not cond: - continue - # aiter::mha_bwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue - gen.append(k) - - return gen - FMHA_BWD_CONVERT_DQ_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; @@ -765,7 +527,7 @@ std::string fmha_bwd_convert_dq_get_name_() }} """ -@dataclass +@dataclass(frozen=True) class FmhaBwdConvertQGradKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim @@ -813,92 +575,256 @@ class FmhaBwdConvertQGradKernel: def filename(self) -> str: return self.name + ".cpp" -def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 +@dataclass(frozen=True) +class FmhaBwdApiTrait: + idx : int # this is not a tunable, but a counter to differentiate symbol + pipeline : str + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : int + dtype : str # data type + mode : str # value from MODE_MAP + tile : FmhaBwdDQDKDVTileSize + mask : str + bias : str + dbias : str + dropout : str + spad : str + spad1 : str # spad for dot/convert kernel + skpad : str + dpad : str + dvpad : str + deterministic : str + mask_impl : str - gen = list() + @property + def 
bm0(self) -> int: + return self.tile.F_bm0 + @property + def bn0(self) -> int: + return self.tile.F_bn0 + @property + def bhdq(self) -> int: + return self.tile.F_bhdq + @property + def bhdv(self) -> int: + return self.tile.F_bhdv + + def scheck(self, spad1 : str) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad == 't' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} != 0' + elif self.spad == 'f' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0' + else: # self.skpad == 'f' and skpad1 == 'f' + return 'a.seqlen_q % 64 == 0' + + @property + def skcheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.skpad == 't': + return f'a.seqlen_k % {self.bn0} != 0' + else: + return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + + @property + def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1, + F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + + @property + def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: + return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, + F_spad=self.spad, F_skpad=self.skpad, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, + F_dbias=self.dbias, F_dropout=self.dropout, F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_pipeline=self.pipeline, mask_impl=self.mask_impl) + + @property + def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, + F_bm0=64, F_bn0=self.tile.F_bn0, F_spad=self.spad, F_dpad=self.dpad, + F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), + F_deterministic=self.deterministic) + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = dict() + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: + # TODO: do we need to check duplication? 
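Freezing these dataclasses makes them immutable and hashable, which is what lets get_bwd_blobs below collect the generated kernels in a Dict[..., Literal[True]] used as an insertion-ordered set, quietly resolving the duplication TODO above: registering an identical kernel twice collapses onto one key. A minimal sketch of the idiom, with hypothetical Kernel fields:

    from dataclasses import dataclass
    from typing import Dict, Literal

    @dataclass(frozen=True)      # frozen + eq (the default) => __hash__ is generated
    class Kernel:
        hdim: int
        dtype: str

    gen: Dict[Kernel, Literal[True]] = {}
    for kern in [Kernel(64, 'fp16'), Kernel(64, 'fp16'), Kernel(128, 'bf16')]:
        gen[kern] = True         # duplicate keys overwrite; insertion order is kept

    assert list(gen) == [Kernel(64, 'fp16'), Kernel(128, 'bf16')]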
+ if trait.dtype not in self.dq_dk_dv_pool.keys(): + self.dq_dk_dv_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): + self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + traits=self.dq_dk_dv_pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + for spad1 in ["t", "f"]: + if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")): + continue + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_deterministic=BOOL_MAP[trait.deterministic]) + + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + +def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: + if filter_list == '': + filter_list = '*@*@*' + filter_list = filter_list.split('@') + filter_list.extend(['*'] * (3 - len(filter_list))) + filter_dot_do_o = filter_list[0] + filter_convert_dq = filter_list[1] + filter_dq_dk_dv = filter_list[2] + + # use dict as ordered set + gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {} + gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {} + gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {} + api_pool = FmhaBwdApiPool(mask_impl) for dtype in BWD_DTYPE_MAP.keys(): d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: + if d is None: continue - for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - hdim = int(hdim_str) + for hdim_str, mode, mask, bias, dbias, dropout, spad, spad1, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 6)): tile = d[hdim_str][0] - if (mode == "group" and spad == "f"): + ppl = d[hdim_str][1] + hdim = int(hdim_str) + if (mode == "group") and (spad == "f" or skpad == "f"): continue - k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, 
F_bn0=tile.F_bn0, - F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): + if (spad1 == "f") and (spad == "t" or mode == "group"): + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + if ("wg32" in dropout): + continue + if (dpad == "t" or dvpad == "t"): + ppl = d[hdim_str][2] + t = FmhaBwdApiTrait(idx=0, pipeline=ppl, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad=spad, spad1=spad1, skpad=skpad, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl) + + if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): + continue + if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv): + continue + if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + + # Flash attention integration + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + if not cond: + continue + elif receipt == 3: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= mode == 'batch' + cond &= deterministic == "f" + if not cond: continue # Aiter (mha_bwd) integration - if receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - if not cond: - continue + elif receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + if not cond: + continue # Aiter (mha_varlen_bwd) integration elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - if not cond: - continue + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + if not cond: + continue # aiter::mha_bwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue - gen.append(k) + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue + gen_dot_do_o[t.dot_do_o_kernel] = True + gen_dq_dk_dv[t.dq_dk_dv_kernel] = True + gen_convert_dq[t.convert_dq_kernel] = True + api_pool.register_dq_dk_dv_traits(t) - return gen - -def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys()) def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (3 - 
len(filter_list))) - # TODO - assert optdim_list == [-1] + api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list) + (output_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + for k in kernels_dot_do_o: + (output_dir / k.filename).write_text(k.template) + for k in kernels_convert_dq: + (output_dir / k.filename).write_text(k.template) + for k in kernels_dq_dk_dv: + (output_dir / k.filename).write_text(k.template) - kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) - for kernel in kernels: - write_single_bwd_dot_do_o_kernel(kernel, output_dir) - kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) - for kernel in kernels: - write_single_bwd_convert_dq_kernel(kernel, output_dir) - api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) - for kernel in kernels: - write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) - write_bwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (3 - len(filter_list))) - # TODO - assert optdim_list == [-1] - - with file_path.open('a') as f: - kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") +def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None: + _, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( + filter_list, receipt, mask_impl, optdim_list + ) + with file_path.open("a") as f: + for k in kernels_dot_do_o: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") + for k in kernels_dq_dk_dv: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") + for k in kernels_convert_dq: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 06a012d277..730641a6b0 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -27,6 +27,7 @@ K0_MAX_SUBMAX_MAP = { 64 : 64, 96 : 128, 128: 128, + 192: 192, 256: 256 } @@ -504,11 +505,11 @@ class KernelComponentFactory: return { (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 
16, -1)], - ### (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } elif dtype == 'fp8' or dtype == 'bf8': @@ -532,31 +533,20 @@ class KernelComponentFactory: pipelines = [] if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - if hdim == 256 and hdim_v == 256: - # if True: + if bias == "bias": + # TODO: rocm 6.2 compiler problem if using qr_async for bias case pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) else: - if bias == "bias": - # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + if receipt == 1 and bias != 
"bias": + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 517e84f380..2e5bc2bd3d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -273,7 +273,7 @@ def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]: else: return None -def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: +def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: @@ -326,6 +326,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # 2 - Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] @@ -334,7 +337,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16, bf16'] + cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' if not cond: continue @@ -350,16 +353,14 @@ def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - assert optdim_list == [-1] - api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - assert optdim_list == [-1] with file_path.open('a') as f: - _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index edc1532a05..5b35e7f0bd 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -637,9 +637,9 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), '64' : FmhaFwdTileSize(64, 64, 
32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - ### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -656,9 +656,9 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d return { '32' : FmhaFwdSplitKVCombineTileSize(32, -1), '64' : FmhaFwdSplitKVCombineTileSize(32, -1), - ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), + '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), - ### '160' : FmhaFwdSplitKVCombineTileSize(32, -1), + '160' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -670,7 +670,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d else: return None -def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: +def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: Pipeline = FmhaFwdSplitKVPipeline Kernel = FmhaFwdSplitKVKernel @@ -746,6 +746,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] @@ -783,7 +786,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> return (api_pool, gen) -def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]: +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]: Pipeline = FmhaFwdSplitKVCombinePipeline Kernel = FmhaFwdSplitKVCombineKernel @@ -830,6 +833,9 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Aiter(mha_varlen_fwd) integration if receipt == 200: cond = dtype in ['fp16', 'bf16'] @@ -855,12 +861,11 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) - assert optdim_list == [-1] - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) + api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) for kernel in kernels: 
write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) @@ -868,13 +873,12 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) - assert optdim_list == [-1] with file_path.open('a') as f: - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) + _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index c611618824..0317330511 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -126,9 +126,6 @@ if __name__ == "__main__": filter_list.extend([''] * (len(api_list) - len(filter_list))) optdim_list = [int(hdim) for hdim in args.optdim.split(',')] - if len(api_list) > 1: - assert optdim_list == [-1] - if args.list_blobs is not None: list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) else: diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index b72485222e..bdd5f2da1b 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -191,8 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << std::flush; diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index da37159aeb..20cc202176 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -23,7 +23,7 @@ args: -n n dimension (default:2048) -k k dimension (default:64) -a_layout Tensor A data layout (default: R) - -b_layout Tensor B data layout (default: R) + -b_layout Tensor B data layout (default: C) -c_layout Tensor C data layout (default: R) -stride_a Tensor A stride (default:0) -stride_b Tensor B stride (default:0) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 80c18cdb87..0d9c2d9957 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -24,7 +24,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { if constexpr(Persistent) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 9deccc7f16..1e867afd1a 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -114,16 +114,16 @@ template struct GemmConfigComputeV3 : public GemmConfigBase { // Compute V3 only support Intrawave scheduler - static 
constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; @@ -241,8 +241,8 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); static constexpr int kBlockPerCu = 2; @@ -263,8 +263,8 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); static constexpr int kBlockPerCu = 2; @@ -475,4 +475,4 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index f57c24f458..34333d5474 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -25,7 +25,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -74,119 +74,120 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = 
ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - static constexpr ck_tile::index_t APackedSize = - std::is_same_v ? 2 : 1; - static constexpr ck_tile::index_t BPackedSize = - std::is_same_v ? 2 : 1; + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 
2 : 1; - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - ck_tile::RotatingMemWrapper rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_preprocess( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) @@ -220,7 +221,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a auto [result, arg_parser] = create_args(argc, argv); bool preshuffle = GemmConfig::Preshuffle; - if(preshuffle && a_layout != "R" && b_layout != "C") + if(preshuffle && (a_layout != "R" || b_layout != "C")) { throw std::runtime_error( "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index f13a4b693b..7f87c2bc06 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -158,7 +158,7 @@ template -float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s); +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); template args = {a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - {}, - c_m_n_dev_buf.GetDeviceBuffer(), - kbatch, - M, - N, - K, - stride_A, - stride_B, - {}, - stride_C}; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C}; float ave_time; 
if(persistent) @@ -315,8 +313,16 @@ int run_gemm_example_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + if constexpr(preshuffle) + { + ck_tile::FillUniformDistribution{-.5f, .5f}(a_m_k); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + } } else if(init_method == 1) { diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index c96a470910..6c60f98fa4 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -25,7 +25,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -74,120 +74,121 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - static constexpr ck_tile::index_t APackedSize = - std::is_same_v ? 2 : 1; - static constexpr ck_tile::index_t BPackedSize = - std::is_same_v ? 2 : 1; + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - ck_tile::RotatingMemWrapper rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_preprocess( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; const auto RunSplitk = [&](const 
auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) diff --git a/example/ck_tile/04_img2col/image_to_column.cpp b/example/ck_tile/04_img2col/image_to_column.cpp index 6380cd2994..299a2f3444 100644 --- a/example/ck_tile/04_img2col/image_to_column.cpp +++ b/example/ck_tile/04_img2col/image_to_column.cpp @@ -149,9 +149,17 @@ int main(int argc, char* argv[]) float ave_time = image_to_column(traits, args, ck_tile::stream_config{nullptr, config.time_kernel}); - std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType)); - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + if(config.time_kernel) + { + std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + } + else + { + std::cout << "image_to_column: pass, No Perf generated due to config.time_kernel=0" + << std::endl; + } bool pass = true; diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp index 28f4c452bc..688f4f3d50 100644 --- a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -333,12 +333,12 @@ struct matrix_core_swizzle_kernel return tmp_1; #else // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, - constexpr index_t kv = Alignment; - constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t waveflatten = kw * nw * kv; - const index_t kr = a_.k / (k1 * k2); - const index_t nr = a_.n / nw; + const index_t kr = a_.k / (k1 * k2); + const index_t nr = a_.n / nw; auto tmp = make_naive_tensor_view_packed( p_dst, make_tuple(nr, kr, waveflatten), @@ -387,8 +387,8 @@ struct matrix_core_swizzle_kernel constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t waveflatten_tile = kw * nw * kv; - constexpr index_t nr_tile = NPerBlock / nw; - constexpr index_t kr_tile = KPerBlock / (kw * kv); + constexpr index_t nr_tile = NPerBlock / nw; + constexpr index_t kr_tile = KPerBlock / (kw * kv); return make_tile_window(dst_view, make_tuple(number{}, number{}, diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index 25598282e3..e0a71452ea 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -15,13 +15,14 @@ auto create_args(int argc, char* argv[]) .insert("v", "1", "cpu validation or not") .insert("prec", "fp16", "precision") .insert("warmup", "0", "cold iter") - .insert("repeat", "1", "hot iter"); + .insert("repeat", "1", "hot iter") + .insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t m = arg_parser.get_int("m"); @@ -81,8 +82,10 @@ bool 
run(const ck_tile::ArgParser& arg_parser) false, // kSaveInvRms false, // kSaveUnquant kTwoPass, - ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add - ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP>; // fuse quant + ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add + ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP, // fuse quant + static_cast( + USEModelSensitive)>; using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem; using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; - using Pipeline = std::conditional_t; + using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass; + + using Pipeline = + std::conditional_t<(PipelineTraits::kUseModelSensitiveRMSNorm == + ck_tile::Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL || + PipelineTraits::kTwoPass), // TODO: consider TwoPass for T5PassPipeline + std::conditional_t, // kUseModelSensitiveRMSNorm + // == 0 + T5PassPipeline>; using Default2DEpilogueProblem = ck_tile:: Default2DEpilogueProblem; @@ -170,9 +183,9 @@ bool run(const ck_tile::ArgParser& arg_parser) } } - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride - << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride + << ", s:" << USEModelSensitive << ", valid:" << (pass ? "y" : "n") << std::flush + << std::endl; } return pass; @@ -184,10 +197,19 @@ int main(int argc, char* argv[]) if(!result) return -1; - const std::string data_type = arg_parser.get_str("prec"); + const std::string data_type = arg_parser.get_str("prec"); + const int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); + if(data_type == "fp16") { - return run(arg_parser) ? 0 : -2; + if(use_model_sensitive_rmsnorm == 0) // 0: for no specific RMSNorm + { + return run(arg_parser) ? 0 : -2; + } + else if(use_model_sensitive_rmsnorm == 1) // 1: for T5-like RMSNorm + { + return run(arg_parser) ? 
0 : -2; + } } return -3; diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index 4296b7373e..b0ba400af1 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -65,7 +65,8 @@ template + ck_tile::index_t kFusedQuant_ = 0, + ck_tile::index_t kUseModelSensitiveRMSNorm_ = 0> struct rmsnorm2d_fwd_traits_ { using XDataType = ck_tile::remove_cvref_t; @@ -127,8 +128,9 @@ struct rmsnorm2d_fwd_traits_ static constexpr bool kSaveInvRms = kSaveInvRms_; static constexpr bool kSaveUnquant = kSaveUnquant_; static constexpr bool kTwoPass = kTwoPass_; - static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; - static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; + static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; + static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; + static constexpr ck_tile::index_t kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_; }; template + int kFusedQuant_, + int kUseModelSensitiveRMSNorm_> using traits_ = rmsnorm2d_fwd_traits_; + kFusedQuant_, + kUseModelSensitiveRMSNorm_>; """ API_COMMON_HEADER = """ @@ -197,7 +201,8 @@ float rmsnorm2d_fwd_(const S& s, A a) Traits_::kSaveUnquant, Traits_::kTwoPass, static_cast(Traits_::kFusedAdd), - static_cast(Traits_::kFusedQuant)>; + static_cast(Traits_::kFusedQuant), + static_cast(Traits_::kUseModelSensitiveRMSNorm)>; using PipelineProblem = ck_tile::Rmsnorm2dFwdPipelineProblem::XDataType, @@ -213,7 +218,13 @@ float rmsnorm2d_fwd_(const S& s, A a) using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; - using Pipeline = std::conditional_t; + using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass; + + using Pipeline = std::conditional_t< + (Traits_::kUseModelSensitiveRMSNorm == 0 || Traits_::kTwoPass), // TODO: consider TwoPass for T5PassPipeline + std::conditional_t, // kUseModelSensitiveRMSNorm == 0 + T5PassPipeline + >; using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; using Default2DEpilogue = ck_tile::Default2DEpilogue; @@ -387,12 +398,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, F_kTwoPass : bool F_kFusedAdd : int F_kFusedQuant : int + F_use_model_sensitive_rmsnorm : int @property def trait_name(self) ->str: t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}' - t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}, {self.F_use_model_sensitive_rmsnorm:4}' return t_ # string when calling this kernel @@ -413,6 +425,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, F_add : int F_sweep : int F_saveunquant : bool + F_use_model_sensitive_rmsnorm : int instance_list : List[Any] # List[h_traits] @property @@ -426,6 +439,10 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] if self.F_saveunquant: nnn = nnn + '_saveunquant' + if self.F_use_model_sensitive_rmsnorm == 0: + nnn = nnn + '_nsm' + elif self.F_use_model_sensitive_rmsnorm == 1: + nnn = nnn + '_t5ml' return nnn 
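The `_nsm`/`_t5ml` suffixes above keep default and T5-like (model-sensitive) RMSNorm instances apart in generated kernel names. A small sketch of the rule, assuming only the flag values 0 and 1 that the generator currently emits; `suffix_for` is an illustrative helper, not part of the script:

```python
# Sketch of the instance-name suffixing shown above; only the 0/1 flag
# values and the suffix strings come from the generator itself.
def suffix_for(use_model_sensitive_rmsnorm: int) -> str:
    if use_model_sensitive_rmsnorm == 0:
        return '_nsm'   # no specific model: default RMSNorm pipeline
    if use_model_sensitive_rmsnorm == 1:
        return '_t5ml'  # T5-like model-sensitive pipeline
    raise ValueError('unexpected use_model_sensitive_rmsnorm value')

name = 'rmsnorm2d_fwd_fp16_n1024' + suffix_for(1)  # -> ..._t5ml
```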
@property @@ -481,9 +498,9 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, elif ins.F_kFusedQuant == 2: _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format( f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant)) - _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( + _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}) && (t.use_model_sensitive_rmsnorm == {f_use_model_sensitive_rmsnorm}) )'.format( f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, - f_sweep_cond = _sweep_cond) + f_sweep_cond = _sweep_cond, f_use_model_sensitive_rmsnorm = ins.F_use_model_sensitive_rmsnorm) inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), F_VEC_COND = _cond, F_instance_func=ins.call_name) #inner_str = inner_str + vec_str @@ -516,85 +533,149 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant bool_list = [False, True] - # rm rn tm tn vn pd mv unquant 2p add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0)], - '512' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0)], - '640' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0), - h_traits('x', 
'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0)]} + h_trait_dicts = { + 0: { + # rm rn tm tn vn pd mv unquant 2p add sweep srm + '64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 0)], + '128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 0)], + '256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 0)], + '512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 0), + 
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 0)], + '640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 0)], + '768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, 
False, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 0)] + }, + 1: { + # rm rn tm tn vn pd mv unquant 2p add sweep srm + '64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 1)], + '128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 1)], + '256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 32, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 1)], + '512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 1)], + '640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 1)], + '768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 1)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 1)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 1)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 1)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 1), + 
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 1)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 1)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 1)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 1)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)] + } + } + total_blob = list() - for hs_key in h_trait_dict: - hs = h_trait_dict[hs_key] - current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N - for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): - prec_i, prec_o = dtype.split(',') - scale_sm, scale_y = scale_type.split(',') - if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2: - continue # skip non dynamic quant case - if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big': - continue - if (fused_quant == 0 and save_unquant == True): - continue # save_unquant should always be false when there is no quant enabled - current_hs = list() - for chs_ in hs: - h_ = copy.copy(chs_) # copy the base instance out - h_.F_XDataType = prec_i - h_.F_YDataType = prec_o - h_.F_SmoothScaleDataType = scale_sm - h_.F_YScaleDataType = scale_y - h_.F_UnquantYDataType = prec_i - h_.F_kFusedAdd = fused_add - h_.F_kFusedQuant = fused_quant - h_.F_kSaveUnquant = save_unquant - current_hs.append(h_) # + "\n" - #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ - current_n_str = 'big' if hs_key == 'big' else current_n - total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, current_hs)) + + for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive + current_trait_dict = h_trait_dicts[model_sensitive_flag] + for hs_key in current_trait_dict: + hs = current_trait_dict[hs_key] + current_n = hs_key + for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): + prec_i, prec_o = dtype.split(',') + scale_sm, scale_y = scale_type.split(',') + if prec_o in dynamic_quant_out_dtype 
and fused_quant != 1 and fused_quant != 2: + continue # skip non dynamic quant case + if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big': + continue + if (fused_quant == 0 and save_unquant == True): + continue # save_unquant should always be false when there is no quant enabled + current_hs = list() + for chs_ in hs: + h_ = copy.copy(chs_) # copy the base instance out + h_.F_XDataType = prec_i + h_.F_YDataType = prec_o + h_.F_SmoothScaleDataType = scale_sm + h_.F_YScaleDataType = scale_y + h_.F_UnquantYDataType = prec_i + h_.F_kFusedAdd = fused_add + h_.F_kFusedQuant = fused_quant + h_.F_kSaveUnquant = save_unquant + current_hs.append(h_) # + "\n" + #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ + current_n_str = 'big' if hs_key == 'big' else current_n + total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, h_.F_use_model_sensitive_rmsnorm, current_hs)) return total_blob def list_blobs(self) -> None: diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index d5be4384ab..751b868411 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -52,7 +52,8 @@ auto create_args(int argc, char* argv[]) .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("warmup", "5", "cold iter") - .insert("repeat", "20", "hot iter"); + .insert("repeat", "20", "hot iter") + .insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -66,15 +67,16 @@ template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - float epsilon = arg_parser.get_float("e"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int fused_add = arg_parser.get_int("fadd"); - int fused_quant = arg_parser.get_int("fquant"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + float epsilon = arg_parser.get_float("e"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + const int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); if(x_stride < 0) @@ -191,13 +193,19 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride - << ", yr_stride:" << yr_stride << std::flush; + << ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush; - rmsnorm2d_fwd_traits traits{ - prec_i, prec_o, prec_sm, prec_sy, SaveRms, SaveUnquant, fused_add, fused_quant}; + rmsnorm2d_fwd_traits traits{prec_i, + prec_o, + prec_sm, + prec_sy, + SaveRms, + SaveUnquant, + fused_add, + fused_quant, + use_model_sensitive_rmsnorm}; 
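The `s` option recorded in `rmsnorm2d_fwd_traits` above is matched at dispatch time against each generated instance's compile-time `kUseModelSensitiveRMSNorm`, alongside the existing vector-size, fused-add, and fused-quant checks. A Python sketch of that selection predicate, mirroring the condition string emitted by `generate.py`; the `Traits` class and `instance_matches` helper are illustrative stand-ins for the generated C++:

```python
# Sketch of the generated instance-selection condition; generate.py now also
# compares t.use_model_sensitive_rmsnorm in addition to the old checks.
from dataclasses import dataclass

@dataclass
class Traits:
    fused_add: int
    fused_quant: int
    use_model_sensitive_rmsnorm: int  # 0: default RMSNorm, 1: T5-like

def instance_matches(n, t, vec_n, f_add, f_quant, f_srm):
    # same shape as the emitted C++: (a.n % vec_n == 0) && (t.fused_add == ...)
    # && (t.fused_quant == ...) && (t.use_model_sensitive_rmsnorm == ...)
    return (n % vec_n == 0
            and t.fused_add == f_add
            and t.fused_quant == f_quant
            and t.use_model_sensitive_rmsnorm == f_srm)

# an s=1 (T5-like) request never selects an s=0 (default) instance:
assert not instance_matches(1024, Traits(0, 0, 1), 8, 0, 0, 0)
```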
rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(), fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp index bb4a2f5ef4..c1090ed28b 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp @@ -64,6 +64,8 @@ struct rmsnorm2d_fwd_traits bool save_unquant; int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant + + int use_model_sensitive_rmsnorm = 0; // 0: Use default RMSNorm; 1: Use T5-like implementation }; float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/10_rmsnorm2d/script/perf_test.sh b/example/ck_tile/10_rmsnorm2d/script/perf_test.sh index 7b9d0820fd..bc4362c105 100755 --- a/example/ck_tile/10_rmsnorm2d/script/perf_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/perf_test.sh @@ -1,37 +1,74 @@ #!/bin/sh EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" -$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +# 0: for no specific RMSNorm +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 -$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 
-repeat=1000 -$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 + +# 1: for T5-like RMSNorm +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=924 -e=1e-12 -v=1 
-prec_i=fp16 -repeat=1000 -s=1
+$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
+$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
+$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
+$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
\ No newline at end of file
diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
index 2bad7a00ea..1c79dafadd 100755
--- a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
+++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
@@ -5,29 +5,32 @@ for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -p
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
+# 0: for no specific RMSNorm; 1: for T5-like RMSNorm
+for s in "0" "1"; do
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
+# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
+# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
+# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
+# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
+# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
+# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
+done
done
done
done
@@ -36,8 +39,11 @@ done
for fquant in ""
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
-$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
+# 0: for no specific RMSNorm; 1: for T5-like RMSNorm
+for s in "0" "1"; do
+$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
+done
diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp
index 06c04b763e..1cd375d0f5 100644
--- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp
+++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp
@@ -105,8 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
b_buf.ToDevice(b_host.data());
gamma_buf.ToDevice(gamma_host.data());
- std::cout << "[" << input_data_type << ", " << quantized_data_type << "]"
- << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
+ std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m
+ << ", n:" << n << ", stride:" << stride << std::flush;
add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX};
diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp
index c43d9c9a2e..449bc17e04 100644
--- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp
+++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp
@@ -256,8 +256,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
- std::cout << "[" << data_type << "]"
- << " m:" << m << ", n:" << n << ", stride:" << stride
+ std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
diff --git a/example/ck_tile/12_smoothquant/example_smoothquant.cpp b/example/ck_tile/12_smoothquant/example_smoothquant.cpp
index 20e1591516..5fcacacee8 100644
--- a/example/ck_tile/12_smoothquant/example_smoothquant.cpp
+++ b/example/ck_tile/12_smoothquant/example_smoothquant.cpp
@@ -216,10 +216,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
- std::cout << "[" << data_type << "]"
- << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
- << ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush
- << std::endl;
+ std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n
+ << ", x_stride:" << x_stride << ", y_stride:" << y_stride
+ << ", valid:" << (pass ?
"y" : "n") << std::flush << std::endl; } return pass; diff --git a/example/ck_tile/12_smoothquant/smoothquant.cpp b/example/ck_tile/12_smoothquant/smoothquant.cpp index f3ba587132..02ab1cd9b1 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/smoothquant.cpp @@ -93,9 +93,8 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); smscale_buf.ToDevice(smscale_host.data()); - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride - << std::flush; + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", y_stride:" << y_stride << std::flush; smoothquant_traits traits{data_type}; diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index f139081cd4..e9b4ea5cd3 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -35,7 +35,20 @@ auto create_args(int argc, char* argv[]) .insert("e", "8", "number of num_experts") .insert("k", "4", "topk") .insert("unit", "32", "unit_size") +#if MOE_SORTING_FMOE_2D_BUF + .insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf") + .insert( + "moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...") +#else .insert("moe_buf_size", "0", "moe_buf_size") +#endif + .insert("ci", + "1", + "clear workspace inside API or not(if \"0\", require manually clear outside)") + .insert( + "dispatch", + "0", + "dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel") .insert("local_eid", "-1", "a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n" @@ -88,10 +101,17 @@ bool test_moe_sorting(ck_tile::ArgParser args) int topk = args.get_int("k"); int seed = args.get_int("seed"); int unit_size = args.get_int("unit"); - int64_t moe_buf_size = static_cast(args.get_uint64("moe_buf_size")); - int kname = args.get_int("kname"); - int warmup = args.get_int("warmup"); - int repeat = args.get_int("repeat"); +#if MOE_SORTING_FMOE_2D_BUF + int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim"); + int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes"); +#else + int64_t moe_buf_size = static_cast(args.get_uint64("moe_buf_size")); +#endif + int kname = args.get_int("kname"); + int warmup = args.get_int("warmup"); + int repeat = args.get_int("repeat"); + bool clear_inside = args.get_int("ci") != 0; + int dispatch_policy = args.get_int("dispatch"); int max_output_ids = ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); @@ -149,11 +169,26 @@ bool test_moe_sorting(ck_tile::ArgParser args) ck_tile::HostTensor sorted_ids_host({max_output_ids}, {1}); ck_tile::HostTensor sorted_weights_host({max_output_ids}, {1}); ck_tile::HostTensor sorted_expert_ids_host({max_output_ids / unit_size}, {1}); - ck_tile::HostTensor sorted_id_cnt_host({1}, {1}); + // for simplicity, below buffer allocate 2 dword + ck_tile::HostTensor sorted_id_cnt_host({2}, {1}); +#if MOE_SORTING_FMOE_2D_BUF + ck_tile::HostTensor moe_buf_host( + {static_cast(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim * + moe_buf_elem_bytes}); + auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast(0) + : moe_buf_host.get_element_space_size_in_bytes(); +#else ck_tile::HostTensor moe_buf_host({moe_buf_size}); + auto moe_buf_bytes = moe_buf_size == 0 ? 
static_cast(0) + : moe_buf_host.get_element_space_size_in_bytes(); +#endif ck_tile::FillUniformDistribution{-.5f, .5f}(weights_host); +#if MOE_SORTING_FMOE_2D_BUF + ck_tile::FillUniformDistribution{-.5f, .5f}(moe_buf_host); +#else ck_tile::FillUniformDistribution{-.5f, .5f}(moe_buf_host); +#endif topid_unique_gen(topk_ids_host.mData, tokens, topk, num_experts, seed); ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes()); @@ -176,7 +211,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) topk_ids_dev.ToDevice(topk_ids_host.data()); weights_dev.ToDevice(weights_host.data()); - if(moe_buf_size > 0) + if(moe_buf_bytes > 0) { moe_buf_dev.ToDevice(moe_buf_host.data()); } @@ -184,12 +219,14 @@ bool test_moe_sorting(ck_tile::ArgParser args) local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr - ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk); + ck_tile::index_t workspace_size = + moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); - if(workspace_size != 0) + if(workspace_size != 0 && clear_inside == false) moe_sorting_ws.SetZero(); // note, clear here!!!! - moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking}; + moe_sorting_trait trait{ + index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy}; moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(), @@ -200,13 +237,19 @@ bool test_moe_sorting(ck_tile::ArgParser args) sorted_weights_dev.GetDeviceBuffer(), sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(), - moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, + moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size, num_experts, topk, - static_cast(moe_buf_size * sizeof(float))}; +#if MOE_SORTING_FMOE_2D_BUF + moe_buf_interm_dim, + moe_buf_elem_bytes +#else + static_cast(moe_buf_size * sizeof(float)) +#endif + }; ck_tile::stream_config sc{nullptr, true, @@ -219,7 +262,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) #if 0 { - ck_tile::HostTensor ws_host({workspace_size}, {1}); + ck_tile::HostTensor ws_host({workspace_size}, {1}); moe_sorting_ws.FromDevice(ws_host.data()); int * p_mesh = reinterpret_cast(ws_host.data()); @@ -268,7 +311,12 @@ bool test_moe_sorting(ck_tile::ArgParser args) } #endif - printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens); + printf("[%s|%s|%s|%d]tokens:%d", + index_prec.c_str(), + weight_prec.c_str(), + workspace_size == 0 ? "cx" : (clear_inside ? 
"ci" : "co"), + dispatch_policy, + tokens); if(is_local_token) { printf("(%d)", local_tokens); @@ -280,6 +328,19 @@ bool test_moe_sorting(ck_tile::ArgParser args) printf("local_eid:%s, ", args.get_str("local_eid").c_str()); } + if(moe_buf_bytes > 0) + { +#if MOE_SORTING_FMOE_2D_BUF + printf("moe_buf:%lu(%d,%d), ", + static_cast(moe_buf_bytes), + moe_buf_interm_dim, + moe_buf_elem_bytes); +#else + + printf("moe_buf:%lu, ", static_cast(moe_buf_bytes)); +#endif + } + if(ms < 0) printf("not supported\n"); else @@ -294,7 +355,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) sorted_weights_dev.FromDevice(sorted_weights_host.data()); sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data()); sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data()); - if(moe_buf_size > 0) + if(moe_buf_bytes > 0) { moe_buf_dev.FromDevice(moe_buf_host.data()); } @@ -340,6 +401,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) std::string("OUT Error: Incorrect eid!"), 1e-6, 1e-6); + // if(is_local_token) + { + auto t_ = is_local_token ? local_tokens : tokens; + bool _f = t_ == sorted_id_cnt_host.mData[1]; + rtn &= _f; + if(!_f) + { + printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]); + } + } } else { @@ -347,9 +418,13 @@ bool test_moe_sorting(ck_tile::ArgParser args) rtn = false; } - if(moe_buf_size) + if(moe_buf_bytes) { +#if MOE_SORTING_FMOE_2D_BUF + ck_tile::HostTensor moe_buf_ref({moe_buf_bytes}); +#else ck_tile::HostTensor moe_buf_ref({moe_buf_size}); +#endif rtn &= ck_tile::check_err( moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 0899fefcfc..a71c5e51a6 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -40,11 +40,11 @@ constexpr bool local_expert_masking = local_expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemEx; \ + ms_weight_type, \ + sub_token_tile, \ + sub_token_onshot, \ + local_expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingKernel; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -175,7 +175,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } } #else - if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0) + if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk, t.dispatch_policy) != 0) { return moe_sorting_mp(t, a, s); } @@ -200,11 +200,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -218,11 +218,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ 
@@ -236,11 +236,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -254,11 +254,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -273,11 +273,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -293,6 +293,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = \ ck_tile::launch_kernel(s, \ + maybe_clear_workspace, \ MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ @@ -302,6 +303,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = \ ck_tile::launch_kernel(s, \ + maybe_clear_workspace, \ MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ @@ -314,6 +316,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = \ ck_tile::launch_kernel(s, \ + maybe_clear_workspace, \ MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ @@ -323,6 +326,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = ck_tile::launch_kernel( \ s, \ + maybe_clear_workspace, \ MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ @@ -330,6 +334,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } \ } +#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ + [&]() { \ + using problem_ = \ + ck_tile::MoeSortingClearWorkspaceProblem; \ + using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { bool is_local_token = a.p_local_tokens != nullptr; @@ -338,6 
+353,22 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co using ms_index_t = ck_tile::index_t; using ms_weight_type = float; + auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) { + if(t.clear_workspace_inside_api) + { + if(is_local_token) + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1); + k(s_); + } + else + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1); + k(s_); + } + } + }; + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > ck_tile::get_smem_capacity()) { @@ -345,6 +376,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co if(t.local_expert_masking) { float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, MOE_SORTING_MP_0(ms_index_t, 1, true), MOE_SORTING_MP_1(ms_index_t, 1, true), MOE_SORTING_MP_2(ms_index_t, 1, true), @@ -354,6 +386,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co else { float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, MOE_SORTING_MP_0(ms_index_t, 1, false), MOE_SORTING_MP_1(ms_index_t, 1, false), MOE_SORTING_MP_2(ms_index_t, 1, false), @@ -405,7 +438,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co return -1; } -int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk) +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy) { - return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy); } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index 0fe8d81e70..6c6cd0f4fa 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -10,8 +10,14 @@ struct moe_sorting_trait { std::string index_type; - std::string weight_type; // currently always float - bool local_expert_masking; // if mask experts as local expert + std::string weight_type; // currently always float + bool local_expert_masking; // if mask experts as local expert + bool clear_workspace_inside_api; // if true, no need to clear the workspace outside (it will be + // taken care of inside the API) + int dispatch_policy; // 0 - let the API choose the kernel for you. 1 - always use the single + // kernel. 2 - always use the mp kernel. NOTE: + // moe_sorting_get_workspace_size() must be called with the same + // dispatch_policy value; it is undefined behavior to use different values + // when querying the workspace size and when launching the kernel
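+ //
+ // Usage sketch (illustrative only; the variable names mirror the test code in
+ // moe_sorting.cpp from this patch and are otherwise assumptions):
+ //
+ //   int dispatch_policy = 0;   // let the API pick the kernel
+ //   bool clear_inside   = true; // workspace will be cleared inside the API
+ //   int ws = moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
+ //   ck_tile::DeviceMem ws_buf(ws != 0 ? ws : 0);
+ //   if(ws != 0 && !clear_inside)
+ //       ws_buf.SetZero(); // caller-side clear, matching clear_workspace_inside_api=false
+ //   moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking,
+ //                           clear_inside, dispatch_policy};
+ //   // fill moe_sorting_args (p_ws = ws != 0 ? ws_buf.GetDeviceBuffer() : nullptr)
+ //   // and call moe_sorting(trait, args, stream) with this same trait.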
}; struct moe_sorting_args : public ck_tile::MoeSortingHostArgs { @@ -22,6 +28,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs // if return non zero, means need workspace, you need to allocate a GPU buffer // and set to moe_sorting_args.p_ws // NOTE: the workspace is required to be cleared to zero before using the API (unless clear_workspace_inside_api is set) -int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk); +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy); float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index 63bc0acceb..2c245f6e7f 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -1,7 +1,9 @@ # #!/bin/sh EXE=./build/bin/tile_example_moe_sorting +MOE_BUF="12" +if [ "x$MOE_BUF" = "x1" ] ; then $EXE -t=80 -e=17 -moe_buf_size=16 $EXE -t=111 -e=117 -moe_buf_size=4 $EXE -t=1000 -e=55 -moe_buf_size=1024 @@ -42,3 +44,46 @@ $EXE -t=23 -local_t=9 -e=1 -k=1 $EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33 $EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33 $EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940 +else +$EXE -t=80 -e=17 -moe_buf_interm_dim=16 -moe_buf_elem_bytes=4 +$EXE -t=111 -e=117 -moe_buf_interm_dim=4 -moe_buf_elem_bytes=4 +$EXE -t=1000 -e=55 -moe_buf_interm_dim=1024 -moe_buf_elem_bytes=1 +$EXE -t=99 -e=120 -moe_buf_interm_dim=10244 -moe_buf_elem_bytes=2 +$EXE -t=175 -e=64 -k=8 +$EXE -t=65 -e=8 -k=2 +$EXE -t=1 -e=25 +$EXE -t=31 -e=19 -k=15 +$EXE -t=81 -e=37 -k=7 +$EXE -t=23 -e=1 -k=1 +$EXE -t=127 -e=99 -k=19 +$EXE -t=71 -e=11 -k=11 +$EXE -t=1 -e=1 -k=1 +$EXE -t=99 -e=2 -k=1 +$EXE -t=333 -e=99 -k=13 +$EXE -t=11 -e=256 -k=5 +$EXE -t=64 -e=455 -k=8 +$EXE -t=777 -e=802 -k=99 +$EXE -t=4097 -e=906 -k=51 +$EXE -t=128 -e=32 -k=5 -local_t=6 -moe_buf_interm_dim=262144 +$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11 +$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19 +$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33 +$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129 +$EXE -t=128 -e=128 -k=6 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1 +$EXE -t=8192 -e=32 -k=5 -local_t=11 -moe_buf_interm_dim=163840 +$EXE -t=8192 -e=32 -k=8 -local_t=12 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1 +$EXE -t=8192 -e=256 -k=5 -local_t=13 -moe_buf_interm_dim=163840 +$EXE -t=8192 -e=256 -k=8 -local_t=8 -moe_buf_interm_dim=163840 +$EXE -t=163840 -e=256 -k=8 -local_t=4 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=4 +$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145 +$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99 +$EXE -t=99 -local_t=93 -e=121 -local_t=4 -moe_buf_interm_dim=10244 +$EXE -t=536 -local_t=345 -e=802 -k=99 +$EXE -t=331 -local_t=39 -e=83 -k=33 +$EXE -t=765 -local_t=654 -e=783 -k=8 +$EXE -t=23 -local_t=9 -e=1 -k=1 +$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33 +$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33 +$EXE -t=133940 -local_t=111921 -e=256 -k=17 -local_t=2 -moe_buf_interm_dim=133940 -moe_buf_elem_bytes=1 + +fi diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp index dc5b397c85..848fb87dcf 100644 ---
a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp @@ -124,9 +124,9 @@ bool run(const ck_tile::ArgParser& arg_parser) smscale_buf.ToDevice(smscale_host.data()); topk_ids_buf.ToDevice(topk_ids_host.data()); - std::cout << "[" << prec_i << "-" << prec_o << "]" - << " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride - << ", experts:" << experts << ", topk:" << topk << std::flush; + std::cout << "[" << prec_i << "-" << prec_o << "]" << " tokens:" << tokens + << ", hidden_size:" << hidden_size << ", stride:" << stride << ", experts:" << experts + << ", topk:" << topk << std::flush; moe_smoothquant_traits traits{prec_i, prec_o}; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index 27274878a2..43ae5cf677 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -6,7 +6,8 @@ int fused_moe_get_workspace_size(int tokens, int num_experts, int topk) { - return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); + return ck_tile::moe_sorting_get_workspace_size( + tokens, num_experts, topk, 0 /*dispatch policy*/); } float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s) @@ -39,8 +40,13 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf a.block_m, // index_t unit_size; a.num_experts, // index_t num_experts; a.topk, // index_t topk; +#if MOE_SORTING_FMOE_2D_BUF + a.stride_token, + o_data_bytes, +#else static_cast(a.num_tokens) * a.stride_token * o_data_bytes // index_t moe_buf_bytes; +#endif }; auto t1 = fused_moegemm_traits{t.prec_i, diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp index 343ddbed13..6e54df9fde 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp @@ -16,11 +16,11 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) { using f_traits = ck_tile::FusedMoeGemmTraits; using f_shape = ck_tile::FusedMoeGemmShape; + typename Ts_::WarpPerBlock_0, + typename Ts_::WarpTile_0, + typename Ts_::BlockTile_1, + typename Ts_::WarpPerBlock_0, + typename Ts_::WarpTile_0>; constexpr auto get_activation_ = []() { if constexpr(Ts_::Activation == 0) diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index f745284f3e..5f87393a0a 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -40,11 +40,11 @@ constexpr bool local_expert_masking = local_expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemEx; \ + ms_weight_type, \ + sub_token_tile, \ + sub_token_onshot, \ + local_expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingKernel; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -204,11 +204,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + 
unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -222,11 +222,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -240,11 +240,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -258,11 +258,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -277,11 +277,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -413,5 +413,6 @@ float fused_moesorting_mp(fused_moesorting_trait t, int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk) { - return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); + return ck_tile::moe_sorting_get_workspace_size( + tokens, num_experts, topk, 0 /*dispatch policy*/); } diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index d9950426a2..e4d87e5fef 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -218,8 +218,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return std::string(", st:") + std::to_string(stride); }(); - std::cout << "[" << api_str << "|" << prec_str << "]" - << " t:" << tokens; + std::cout << "[" << api_str << "|" << prec_str << "]" << " t:" << tokens; if(is_local_token) { @@ -399,7 +398,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr ck_tile::index_t workspace_size = - ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk); + ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch_policy*/); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); if(workspace_size != 0) moe_sorting_ws.SetZero(); // note, clear here!!!! 
diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 7d5e1910dd..6d26cfe675 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -50,21 +50,20 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_warmup, int n_repeat) { - ck_tile::BatchedGemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_E = stride_C; - args.batch_stride_A = batch_stride_A; - args.batch_stride_B = batch_stride_B; - args.batch_stride_E = batch_stride_C; - args.batch_count = batch_count; + ck_tile::BatchedGemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count}; float ave_time = batched_gemm& gemm_descs, if(s.log_level_ > 0) { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } ave_time = diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index c4e83617d3..74efb1bdeb 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -54,7 +54,7 @@ using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; -using grouped_gemm_kargs = ck_tile::GemmHostArgs; +using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; auto create_args(int argc, char* argv[]) { diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp index 4107181520..897952f03c 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp @@ -138,10 +138,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, if(s.log_level_ > 0) { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } ave_time = diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 5ed1219731..fa7f1a31c1 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -83,18 +83,18 @@ float invoke_gemm(int n_warmup, const bool splitk = 
args[0].k_batch > 1; for(const auto& arg : args) { - kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr, - arg.b_ptr, - {}, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.stride_A, - arg.stride_B, - {}, - arg.stride_E, - arg.k_batch}); + kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, + {arg.b_ptr}, + {/*arg.ds_ptr*/}, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + {/*arg.stride_Ds*/}, + arg.stride_E, + arg.k_batch}); } const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, @@ -216,9 +216,9 @@ int run_grouped_gemm_example_with_layouts(int argc, c_m_n_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); - std::cout << "gemm[" << i << "]" - << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc - << " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl; + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc + << std::endl; ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); @@ -240,7 +240,7 @@ int run_grouped_gemm_example_with_layouts(int argc, void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); gemm_descs.push_back( - {p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]}); + {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } invoke_gemm) + { + return "bf16"; + } else { return "unknown"; diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp index 6c5ca08426..8971871c14 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -157,7 +157,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& UniversalGemmProblem::TransposeC, memory_operation>>; - using Kernel = ck_tile::GemmKernel; + using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); @@ -170,10 +170,9 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " + << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " + << blocks.y << ", " << blocks.z << "}" << std::endl; } ave_time = ck_tile::launch_kernel( diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp index 3ce3965e56..87b9592553 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp @@ -64,7 +64,7 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -using gemm_multi_d_kargs = ck_tile::GemmHostArgs; +using gemm_multi_d_kargs = ck_tile::GemmMultiDHostArgs; template + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "grouped_convolution_utils.hpp" + +template , + typename DsLayout = ck_tile::tuple<>, + typename 
CDEElementWise = ck_tile::element_wise::PassThrough> +float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, + const ck_tile::stream_config& s) +{ + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr ck_tile::index_t VectorSizeA = 8; + constexpr ck_tile::index_t VectorSizeB = 8; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = + ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = + ck_tile::GemmPipelineProblem; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + float ave_time = ck_tile::launch_kernel_preprocess( + s, + Kernel::Preprocess(kargs, s), + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } +} + +#include "run_grouped_convolution_bwd_weight_example.inc" + +template +int run_grouped_conv_bwd_weight_example_prec_type( + std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) +{ + using NWGC = ck_tile::tensor_layout::convolution::NWGC; + using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; + using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; + + using GKXC = ck_tile::tensor_layout::convolution::GKXC; + using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; + using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; + + using NWGK = ck_tile::tensor_layout::convolution::NWGK; + using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; + using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; + + if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") + { + return 
run_grouped_conv_bwd_weight_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NWGC{}, GKXC{}, NWGK{}); + } + else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") + { + return run_grouped_conv_bwd_weight_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); + } + else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") + { + return run_grouped_conv_bwd_weight_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); + } + else + { + throw std::runtime_error("Unsupported memory layout!"); + } +} + +int run_grouped_conv_bwd_weight_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string data_type = arg_parser.get_str("prec"); + std::string in_layout = arg_parser.get_str("in_layout"); + std::string wei_layout = arg_parser.get_str("wei_layout"); + std::string out_layout = arg_parser.get_str("out_layout"); + + if(data_type == "fp16") + { + return run_grouped_conv_bwd_weight_example_prec_type( + in_layout, wei_layout, out_layout, argc, argv); + } + else if(data_type == "bf16") + { + return run_grouped_conv_bwd_weight_example_prec_type( + in_layout, wei_layout, out_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation!"); + } +} + +int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index 685fdccde2..ce19c77bc1 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -23,7 +23,7 @@ template , typename DsLayout = ck_tile::tuple<>, typename CDEElementWise = ck_tile::element_wise::PassThrough> -float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s) +float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s) { constexpr int kBlockPerCu = 1; @@ -97,7 +97,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile:: ConvEpilogue>; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args); + const dim3 grids = Kernel::GridSize(kargs); constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) @@ -129,7 +129,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile:: ck_tile::memory_operation_enum::set>{}); } -#include "run_grouped_convolution_example.inc" +#include "run_grouped_convolution_fwd_example.inc" template int run_grouped_conv_fwd_example_prec_type( @@ -185,7 +185,7 @@ int run_grouped_conv_fwd_example(int argc, char* argv[]) std::string data_type = arg_parser.get_str("prec"); std::string in_layout = arg_parser.get_str("in_layout"); - std::string wei_layout = arg_parser.get_str("weight_layout"); + std::string wei_layout = arg_parser.get_str("wei_layout"); std::string out_layout = arg_parser.get_str("out_layout"); if(data_type == "fp16") diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index cc8d365b18..f3a7a60fd9 100644 --- 
a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -12,6 +12,28 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/grouped_convolution.hpp" +template +auto calculate_rtol_atol(const ck_tile::index_t GemmK, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(GemmK, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = + ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + ck_tile::index_t fill_spatial_dimensions(std::vector& filter_spatial_lengths, std::vector& image_spatial_lengths, std::vector& strides, @@ -90,7 +112,7 @@ auto create_args(int argc, char* argv[]) .insert("rpad_w", "0", "right pad for w dimension") .insert("in_layout", "NHWGC", "Input image layout - NHWGC by default") - .insert("weight_layout", "GKYXC", "Weight layout - GKYXC by default") + .insert("wei_layout", "GKYXC", "Weight layout - GKYXC by default") .insert("out_layout", "NHWGK", "Output image layout - NHWGK by default") .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") @@ -105,4 +127,5 @@ auto create_args(int argc, char* argv[]) } // host API -float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s); +float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, + const ck_tile::stream_config& s); diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc new file mode 100644 index 0000000000..637ea2fbfb --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
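+//
+// This include implements the host-side driver for the grouped convolution
+// backward-weight example: invoke_grouped_conv_bwd_weight() launches the kernel and
+// reports performance, while run_grouped_conv_bwd_weight_example_with_layouts() sets
+// up host/device tensors, runs the kernel and (optionally) verifies the result on the
+// CPU. Since ave_time below is measured in milliseconds, flop / 1e9 / ave_time yields
+// TFLOP/s and num_byte / 1e6 / ave_time yields GB/s.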
+#pragma once + +template +float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args, + int n_warmup, + int n_repeat) +{ + float ave_time = grouped_conv_bwd_weight( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = args.GetFlops(); + std::size_t num_byte = args.GetByte(); + float tflops = static_cast<float>(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_grouped_conv_bwd_weight_example_with_layouts( + int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using AccDataType = float; + + std::vector filter_spatial_lengths; + std::vector image_spatial_lengths; + std::vector strides; + std::vector dilations; + std::vector lpads; + std::vector rpads; + + const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads, + arg_parser); + + ck_tile::conv::ConvParam conv_param{num_dim_sp, + arg_parser.get_int("g"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("c"), + filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads}; + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + const auto in_g_n_c_wis_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_g_n_c_wis_desc); + ck_tile::HostTensor weight(wei_g_k_c_xs_desc); + ck_tile::HostTensor output(out_g_n_k_wos_desc); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(input); + ck_tile::FillUniformDistribution{-1.f, 1.f}(output); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(input); + ck_tile::FillMonotonicSeq{}(output); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(input); + ck_tile::FillUniformDistribution{1.f, 1.f}(output); + } + else + { + input.SetZero(); + output.SetZero(); + } + + ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes()); + + input_dev_buf.ToDevice(input.data()); + weight_dev_buf.SetZero(); + output_dev_buf.ToDevice(output.data()); + + ck_tile::GroupedConvBwdWeightHostArgs args(conv_param, + input_dev_buf.GetDeviceBuffer(), + weight_dev_buf.GetDeviceBuffer(), + {}, + output_dev_buf.GetDeviceBuffer(), + kbatch); + + std::cout << "Run Grouped Conv Bwd Weight kernel" << std::endl; + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + invoke_grouped_conv_bwd_weight(args, n_warmup, n_repeat); + + weight_dev_buf.FromDevice(weight.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor
weight_host_ref(wei_g_k_c_xs_desc); + weight_host_ref.SetZero(); + + ck_tile:: + reference_grouped_conv_bwd_weight( + input, + weight_host_ref, + output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_); + const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_); + const float max_accumulated_value = + *std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + GemmK, kbatch, max_accumulated_value); + pass = ck_tile::check_err(weight, + weight_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + throw std::runtime_error("Unsupported gpu verification !!!"); + } + + return pass; +} diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc similarity index 81% rename from example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc rename to example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc index ed72eb354d..3532e343bb 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc @@ -2,28 +2,6 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -template -auto calculate_rtol_atol(const ck_tile::index_t GemmK, - const ck_tile::index_t kbatch, - const float max_accumulated_value) -{ - using ComputeType = - std::conditional_t; - // Calculate thresholds - const auto rtol = ck_tile::get_relative_threshold( - ck_tile::integer_divide_ceil(GemmK, kbatch)); - const auto atol = ck_tile::get_absolute_threshold( - max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch)); - // Calculate error due to split_k accumulation - const auto rtol_split_k = - ck_tile::get_relative_threshold(kbatch); - const auto atol_split_k = - ck_tile::get_absolute_threshold( - max_accumulated_value, kbatch); - // Use higher threshold - return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); -} - template -float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat) +float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, + int n_warmup, + int n_repeat) { float ave_time = grouped_conv_fwd +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + + // If stride is negative (default -1), set it to N, assuming a dense row-major layout. + if(stride < 0) + stride = N; + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + if(stride < N) + { + throw std::runtime_error("stride must be >= N"); + } + + // Define type aliases for clarity. + // XDataType: Data type of the input tensors. 
+ // ComputeDataType: Data type used for intermediate computations (often float for precision). + // YDataType: Data type of the output tensor. + // XElementwiseOperation: The specific elementwise operation to perform (e.g., Add, Mul). + using XDataType = DataType; + using ComputeDataType = + float; // Using float for intermediate calculations can improve numerical stability. + using YDataType = DataType; + using XElementwiseOperation = ck_tile::element_wise::Add; + + // 1. Initialize the input data on the host (CPU). + // HostTensor is a utility to manage tensor data on the CPU. + // The first argument is the shape (dimensions) of the tensor {M, N}. + // The second argument is the strides {stride, 1} for row-major layout. + // 'x_host_a' and 'x_host_b' are the two input tensors for the elementwise operation. + ck_tile::HostTensor x_host_a({M, N}, {stride, 1}); + ck_tile::HostTensor x_host_b({M, N}, {stride, 1}); + ck_tile::HostTensor y_host({M, N}, {stride, 1}); + ck_tile::HostTensor y_validation({M, N}, {stride, 1}); + + std::vector shape = {M, N}; + + // Fill the host tensors with random data. + // FillUniformDistribution populates the tensor with values from a uniform distribution, + // within an interval. + ck_tile::FillUniformDistribution{0.f, 5.f}(x_host_a); + ck_tile::FillUniformDistribution{0.f, 5.f}(x_host_b); + + // 2. Create device memory buffers + // DeviceMem allocates memory on the GPU. + // The size is determined by the total number of elements and the size of DataType. + ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); + + // Copy data from host input tensors to device buffers. + x_buf_a.ToDevice(x_host_a.data()); + x_buf_b.ToDevice(x_host_b.data()); + + // 3. Configure the kernel execution parameters. + // Dividing the problem into blocktile, blockwarp and warptile + // The blocktile is the size of the tile processed by a single work group (also called thread + // block). The warptile is the size of the tile processed by a single wavefront (also called + // warp). The vector is the size of the tile processed by a single work item (also called + // thread). The problem is divided into blocks of size BlockTile. Each block is further divided + // into wavefronts of size WarpTile. Each wavefront is composed of 64 work items (on AMD; 32 + // threads on NVIDIA). Each work item in a wavefront processes one vector's worth of elements. + // Note that WarpTile/Vector should be 64 for CDNA (because there are 64 work items per + // wavefront). Vector size is set to be 16 / sizeof(ComputeDataType), to maximize vectorization. + using BlockTile = ck_tile::sequence<2048>; // How many elements are handled by a block tile (the + // tensor is divided into blocks of this size) + using BlockWarps = ck_tile::sequence<8>; // How many concurrent wavefronts are in a block (each + // wavefront will cover some part of the block tile) + + // WarpTile: Defines the size of the data sub-tile processed by a single wavefront. + // This should be consistent with BlockTile and BlockWarps. + // If BlockTile is 2048 and BlockWarps is 8, then WarpTile could be 2048/8 = 256. + // However, this example uses 64, meaning each wavefront processes 64 elements, and multiple + // such wavefront operations might be needed to cover the BlockTile, or the BlockTile is + // distributed differently. 
+ // The current configuration (BlockTile=2048, BlockWarps=8, WarpTile=64) implies that + // each wavefront processes 64 elements, and 8 wavefronts process 8*64 = 512 elements + // concurrently. Since 512 is not equal to 2048, the wavefronts need to iterate multiple + // times, over different sets of elements, to cover the entire BlockTile. + using WarpTile = ck_tile::sequence<64>; + + // 4. Create the kernel + + // ElementWiseShape bundles these tiling parameters. + // It calculates derived properties like threads per wavefront, repeats, vectorization and total + // block size. + using Shape = ck_tile::ElementWiseShape; + + // ElementWisePipelineProblem encapsulates all necessary information for the elementwise kernel: + // - Data types (input, compute, output). + // - Shape traits (tiling configuration). + // - The specific elementwise operation (e.g., Add). + using Problem = ck_tile::ElementWisePipelineProblem; + + // ElementWiseKernel refers to the GPU kernel class + using Kernel = ck_tile::ElementWiseKernel; + + // Compute flattened size + ck_tile::index_t total_elements = 1; + for(auto d : shape) + total_elements *= d; + + // kBlockSize: The number of work items in a GPU workgroup (thread block). + // This is often a multiple of the wavefront size, 64 on CDNA. + // Here, it's explicitly set to 512. This should be consistent with Shape::kBlockSize. + // Shape::kBlockSize would be BlockWarps * warpSize (e.g., 8 * 64 = 512). + constexpr ck_tile::index_t kBlockSize = + ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); + + // kBlockPerCu: Hint for how many workgroups can be scheduled per Compute Unit (CU). + // This can influence occupancy and performance. + constexpr ck_tile::index_t kBlockPerCu = 1; + + // kGridSize: Calculates the total number of workgroups required to process all elements. + // Each workgroup is responsible for 'elements_per_block' elements. + // To ensure all elements are covered, especially when 'total_elements' is not perfectly + // divisible by 'elements_per_block', using ceiling division. + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; + + std::cout << "grid size = " << kGridSize << std::endl; + std::cout << "Total elements = " << total_elements << std::endl; + + auto input_tensors = ck_tile::make_tuple(static_cast(x_buf_a.GetDeviceBuffer()), + static_cast(x_buf_b.GetDeviceBuffer())); + + auto input_size = ck_tile::make_tuple(M, N); + + // Check if the kernel configuration is supported + if(!Kernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "The kernel configuration is not supported for the given input size."); + } + + // 5. Run the kernel + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(N, 1), // Input Stride + ck_tile::make_tuple(N, 1), // Output Stride + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); + + std::cout << "Average time: " << ave_time << " ms" << std::endl; +
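+ // Worked example of the launch arithmetic above (a sketch; the concrete m/n come
+ // from the command line): with BlockTile=2048, BlockWarps=8 and WarpTile=64, one
+ // wavefront covers 64 elements per pass, so a workgroup of 8 wavefronts covers
+ // 8 * 64 = 512 elements per pass and sweeps a 2048-element block tile in
+ // 2048 / 512 = 4 passes. kBlockSize is 8 * 64 = 512 work items, and for, say,
+ // M=700 and N=1024, kGridSize = ceil(716800 / 2048) = 350 workgroups.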
+
+    auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
+                                             static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()));
+
+    auto input_size = ck_tile::make_tuple(M, N);
+
+    // Check whether the kernel configuration is supported.
+    if(!Kernel::IsSupportedArgument(input_size))
+    {
+        throw std::runtime_error(
+            "The kernel configuration is not supported for the given input size.");
+    }
+
+    // 5. Run the kernel.
+    float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
+                                   ck_tile::make_kernel(
+                                       Kernel{},
+                                       kGridSize,
+                                       kBlockSize,
+                                       0,
+                                       input_size,
+                                       ck_tile::make_tuple(N, 1), // input strides
+                                       ck_tile::make_tuple(N, 1), // output strides
+                                       input_tensors,
+                                       static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
+
+    std::cout << "Average time: " << ave_time << " ms" << std::endl;
+
+    // 6. Verify the output.
+    bool pass = true;
+    if(do_validation)
+    {
+        y_buf.FromDevice(y_validation.data());
+        auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
+
+        ck_tile::reference_binary_elementwise(
+            x_host_a, x_host_b, y_host, op);
+
+        pass = ck_tile::check_err(
+            y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
+    }
+
+    return pass;
+}
+
+int main(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    const std::string data_type = arg_parser.get_str("prec");
+    if(data_type == "fp16")
+    {
+        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+    }
+
+    return -3;
+}
diff --git a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp
new file mode 100644
index 0000000000..f18a910813
--- /dev/null
+++ b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp
@@ -0,0 +1,159 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck_tile/core/arch/arch.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/elementwise.hpp"
+#include "ck_tile/host/reference/reference_elementwise.hpp"
+
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("dim0", "4", "dimension 0")
+        .insert("dim1", "16", "dimension 1")
+        .insert("dim2", "32", "dimension 2")
+        .insert("dim3", "32", "dimension 3")
+        .insert("v", "1", "cpu validation or not")
+        .insert("prec", "fp16", "precision")
+        .insert("warmup", "10", "cold iter")
+        .insert("repeat", "50", "hot iter");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+template <typename DataType>
+bool run(const ck_tile::ArgParser& arg_parser)
+{
+    ck_tile::index_t D0 = arg_parser.get_int("dim0");
+    ck_tile::index_t D1 = arg_parser.get_int("dim1");
+    ck_tile::index_t D2 = arg_parser.get_int("dim2");
+    ck_tile::index_t D3 = arg_parser.get_int("dim3");
+
+    std::string data_type = arg_parser.get_str("prec");
+    int do_validation = arg_parser.get_int("v");
+    int warmup = arg_parser.get_int("warmup");
+    int repeat = arg_parser.get_int("repeat");
+
+    using XDataType = DataType;
+    using ComputeDataType =
+        float; // Using float for intermediate calculations can improve numerical stability.
+    using YDataType = DataType;
+    using XElementwiseOperation = ck_tile::element_wise::Add;
+
+    // Initialize the input data on the host (CPU).
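+    // A 4D tensor is stored densely in row-major order, so the stride of each dimension is the
+    // product of all faster-varying dimensions. For the default shape {4, 16, 32, 32} this gives
+    // strides {16*32*32, 32*32, 32, 1} = {16384, 1024, 32, 1}, and element (i0, i1, i2, i3)
+    // lives at flat offset i0*16384 + i1*1024 + i2*32 + i3.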
+    std::vector<ck_tile::index_t> problem_shape = {D0, D1, D2, D3};
+
+    std::vector<ck_tile::index_t> host_strides(4);
+    host_strides[3] = 1;
+    host_strides[2] = problem_shape[3];
+    host_strides[1] = problem_shape[2] * problem_shape[3];
+    host_strides[0] = problem_shape[1] * problem_shape[2] * problem_shape[3];
+
+    ck_tile::HostTensor<XDataType> x_host_a(problem_shape, host_strides);
+    ck_tile::HostTensor<XDataType> x_host_b(problem_shape, host_strides);
+    ck_tile::HostTensor<YDataType> y_host(problem_shape, host_strides);
+    ck_tile::HostTensor<YDataType> y_validation(problem_shape, host_strides);
+
+    ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
+    ck_tile::FillUniformDistribution<XDataType>{2.f, 10.f}(x_host_b);
+
+    ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());
+
+    x_buf_a.ToDevice(x_host_a.data());
+    x_buf_b.ToDevice(x_host_b.data());
+
+    using BlockTile  = ck_tile::sequence<256>;
+    using BlockWarps = ck_tile::sequence<1>;
+    using WarpTile   = ck_tile::sequence<256>;
+
+    using Shape = ck_tile::ElementWiseShape<BlockTile, BlockWarps, WarpTile>;
+
+    using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
+                                                        ComputeDataType,
+                                                        YDataType,
+                                                        Shape,
+                                                        XElementwiseOperation>;
+
+    using Kernel = ck_tile::ElementWiseKernel<Problem>;
+
+    ck_tile::index_t total_elements = 1;
+    for(auto d : problem_shape)
+        total_elements *= d;
+
+    constexpr ck_tile::index_t kBlockSize =
+        ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
+
+    constexpr ck_tile::index_t kBlockPerCu = 2;
+
+    constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
+    ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
+
+    std::cout << "grid size = " << kGridSize << std::endl;
+    std::cout << "Total elements = " << total_elements << std::endl;
+
+    auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
+                                             static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()));
+
+    auto problem_shape_tuple =
+        ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
+
+    auto strides_tuple =
+        ck_tile::make_tuple(host_strides[0], host_strides[1], host_strides[2], host_strides[3]);
+
+    // Check whether the kernel configuration is supported.
+    if(!Kernel::IsSupportedArgument(problem_shape_tuple))
+    {
+        throw std::runtime_error(
+            "The kernel configuration is not supported for the given input size.");
+    }
+
+    // Run the kernel.
+    float ave_time = launch_kernel(
+        ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
+        ck_tile::make_kernel(
+            Kernel{},
+            kGridSize,
+            kBlockSize,
+            0,
+            problem_shape_tuple, // problem shape as a ck_tile::tuple
+            strides_tuple,       // input strides
+            strides_tuple,       // output strides (the output shares the dense layout)
+            input_tensors,
+            static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
+
+    std::cout << "Average time: " << ave_time << " ms" << std::endl;
+
+    // Verify the output.
+    bool pass = true;
+    if(do_validation)
+    {
+        y_buf.FromDevice(y_validation.data());
+        auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
+
+        ck_tile::reference_binary_elementwise(
+            x_host_a, x_host_b, y_host, op);
+
+        pass = ck_tile::check_err(
+            y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
+    }
+
+    return pass;
+}
+
+int main(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    const std::string data_type = arg_parser.get_str("prec");
+    if(data_type == "fp16")
+    {
+        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+    }
+
+    return -3;
+}
diff --git a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp
new file mode 100644
index 0000000000..affc337c38
--- /dev/null
+++ b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp
@@ -0,0 +1,156 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/elementwise.hpp"
+#include "ck_tile/host/reference/reference_transpose.hpp"
+
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("m", "1024", "m dimension of input")
+        .insert("n", "1024", "n dimension of input")
+        .insert("stride_in", "-1", "stride for input M dim, if -1 then equal to n")
+        .insert("v", "1", "cpu validation or not")
+        .insert("prec", "fp16", "precision")
+        .insert("warmup", "10", "cold iter")
+        .insert("repeat", "50", "hot iter");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+template <typename DataType>
+bool run(const ck_tile::ArgParser& arg_parser)
+{
+    ck_tile::index_t M = arg_parser.get_int("m");
+    ck_tile::index_t N = arg_parser.get_int("n");
+    ck_tile::index_t stride_in = arg_parser.get_int("stride_in");
+
+    if(stride_in < 0)
+        stride_in = N; // Dense input: the stride of the M dimension is N.
+    std::string data_type = arg_parser.get_str("prec");
+    int do_validation = arg_parser.get_int("v");
+    int warmup = arg_parser.get_int("warmup");
+    int repeat = arg_parser.get_int("repeat");
+
+    if(stride_in < N)
+    {
+        throw std::runtime_error("stride_in must be >= N");
+    }
+
+    using XDataType = DataType;
+    using ComputeDataType = float;
+    using YDataType = DataType;
+    // Use the PassThrough operation for transposition (data is moved, not changed).
+    using XElementwiseOperation = ck_tile::element_wise::PassThrough;
+
+    // 1. Initialize the input data on the host (CPU).
+    // Input x_host_a: M x N
+    // Output y_host: N x M (transposed)
+    ck_tile::HostTensor<XDataType> x_host_a({M, N}, {stride_in, 1});
+    // The output tensor y_host has dimensions N x M.
+    // Assuming a dense output, its stride for the N dimension is M.
+    ck_tile::index_t stride_out_dim0 = M;
+    ck_tile::HostTensor<YDataType> y_host({N, M}, {stride_out_dim0, 1});
+    ck_tile::HostTensor<YDataType> y_validation({N, M}, {stride_out_dim0, 1});
+
+    // The logical shape for the elementwise operation kernel is based on the input tensor's
+    // elements.
+    std::vector<ck_tile::index_t> op_shape_vec = {M, N};
+    auto op_lengths = ck_tile::make_tuple(M, N); // Lengths passed to the kernel.
+
+    ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
+
+    // 2. Create device memory buffers.
+    ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); // y_host is N x M
+
+    x_buf_a.ToDevice(x_host_a.data());
+
+    // 3. Configure the kernel execution parameters.
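+    // How the transpose is expressed: the kernel iterates over the logical (M, N) index space of
+    // the input; only the strides differ between input and output. Reading element (i, j) uses
+    // offset i*stride_in + j, while the output strides (1, M) configured below make the write
+    // land at offset i + j*M, which is exactly position (j, i) of the N x M output. For example,
+    // with M = N = 4, input element (1, 2) at offset 1*4 + 2 = 6 is written to offset
+    // 1 + 2*4 = 9, i.e. row 2, column 1 of the output.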
+    using BlockTile  = ck_tile::sequence<1024>;
+    using BlockWarps = ck_tile::sequence<8>;
+    using WarpTile   = ck_tile::sequence<64>;
+
+    using Shape = ck_tile::ElementWiseShape<BlockTile, BlockWarps, WarpTile>;
+
+    // Problem definition for a single input tensor.
+    using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
+                                                        ComputeDataType,
+                                                        YDataType,
+                                                        Shape,
+                                                        XElementwiseOperation>;
+
+    using Kernel = ck_tile::ElementWiseKernel<Problem>;
+
+    ck_tile::index_t total_elements = M * N;
+
+    constexpr ck_tile::index_t kBlockSize =
+        ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
+    constexpr ck_tile::index_t kBlockPerCu = 1;
+    constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
+    ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
+
+    std::cout << "Input M=" << M << ", N=" << N << ", StrideIn=" << stride_in << std::endl;
+    std::cout << "Output N=" << N << ", M=" << M << ", StrideOut=" << stride_out_dim0 << std::endl;
+    std::cout << "Grid size = " << kGridSize << ", BlockSize = " << kBlockSize << std::endl;
+    std::cout << "Total elements = " << total_elements << std::endl;
+
+    // Input tensors tuple (single input).
+    auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()));
+    // Strides of the (single) input tensor.
+    auto input_strides = ck_tile::make_tuple(stride_in, 1);
+    // Output strides; writing with strides (1, M) produces the dense N x M transpose.
+    auto output_strides = ck_tile::make_tuple(1, stride_out_dim0);
+
+    // Check whether the kernel configuration is supported.
+    if(!Kernel::IsSupportedArgument(op_lengths))
+    {
+        throw std::runtime_error(
+            "The kernel configuration is not supported for the given input size.");
+    }
+
+    // 4. Run the kernel.
+    float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
+                                   ck_tile::make_kernel(
+                                       Kernel{},
+                                       kGridSize,
+                                       kBlockSize,
+                                       0,              // shared memory size
+                                       op_lengths,     // logical dimensions of the operation (M, N)
+                                       input_strides,  // strides for the input tensor
+                                       output_strides, // strides for the output tensor (N x M)
+                                       input_tensors,
+                                       static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
+
+    std::cout << "Average time: " << ave_time << " ms" << std::endl;
+
+    // 5. Verify the output.
+    bool pass = true;
+    if(do_validation)
+    {
+        y_buf.FromDevice(y_validation.data()); // Copy the result from device to y_validation.
+        ck_tile::reference_transpose_elementwise(
+            x_host_a, y_host); // Compute the reference on the host.
+        pass = ck_tile::check_err(
+            y_validation, y_host, "Transpose Error: Incorrect results!", 0.01, 0.01);
+    }
+
+    return pass;
+}
+
+int main(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    const std::string data_type = arg_parser.get_str("prec");
+    if(data_type == "fp16")
+    {
+        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+    }
+
+    std::cerr << "Unsupported data type: " << data_type << std::endl;
+    return -3;
+}
diff --git a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp
new file mode 100644
index 0000000000..147dfd3424
--- /dev/null
+++ b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp
@@ -0,0 +1,147 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck_tile/core/arch/arch.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/elementwise.hpp"
+#include "ck_tile/host/reference/reference_elementwise.hpp"
+
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("m", "1024", "m dimension")
+        .insert("n", "1024", "n dimension")
+        .insert("stride", "-1", "stride per row, if -1 then equal to n")
+        .insert("v", "1", "cpu validation or not")
+        .insert("prec", "fp16", "precision")
+        .insert("warmup", "10", "cold iter")
+        .insert("repeat", "50", "hot iter");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+template <typename DataType>
+bool run(const ck_tile::ArgParser& arg_parser)
+{
+    ck_tile::index_t M = arg_parser.get_int("m");
+    ck_tile::index_t N = arg_parser.get_int("n");
+    ck_tile::index_t stride = arg_parser.get_int("stride");
+    if(stride < 0)
+        stride = N;
+    std::string data_type = arg_parser.get_str("prec");
+    int do_validation = arg_parser.get_int("v");
+    int warmup = arg_parser.get_int("warmup");
+    int repeat = arg_parser.get_int("repeat");
+
+    assert(stride >= N);
+
+    using XDataType = DataType;
+    using YDataType = DataType;
+    using ComputeDataType = float;
+    using XElementwiseOperation = ck_tile::element_wise::UnarySquare;
+
+    // 1. Initialize the input data on the host.
+    ck_tile::HostTensor<XDataType> x_host_a({M, N}, {stride, 1});
+    ck_tile::HostTensor<YDataType> y_host({M, N}, {stride, 1});
+    ck_tile::HostTensor<YDataType> y_validation({M, N}, {stride, 1});
+
+    std::vector<ck_tile::index_t> shape = {M, N};
+
+    ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
+
+    // 2. Create device memory buffers and copy the input data from host to device.
+    ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());
+    x_buf_a.ToDevice(x_host_a.data());
+
+    // 3. Create the kernel.
+
+    // Divide the problem into block tiles, warp tiles, and vectors.
+    using BlockTile  = ck_tile::sequence<2048>; // Size of the block tile (the entire problem is
+                                                // divided into blocks of this size).
+    using BlockWarps = ck_tile::sequence<8>;    // How many concurrent warps are in a block (each
+                                                // warp covers part of the block tile).
+    using WarpTile   = ck_tile::sequence<64>;   // How many elements one warp covers per pass
+                                                // (8 warps x 64 = 512 elements per pass, so each
+                                                // block iterates 4 times over its 2048 elements).
+
+    using Shape = ck_tile::ElementWiseShape<BlockTile, BlockWarps, WarpTile>;
+    using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
+                                                        ComputeDataType,
+                                                        YDataType,
+                                                        Shape,
+                                                        XElementwiseOperation>;
+
+    using Kernel = ck_tile::ElementWiseKernel<Problem>;
+
+    // Compute the flattened size.
+    ck_tile::index_t total_elements = 1;
+    for(auto d : shape)
+        total_elements *= d;
+
+    constexpr ck_tile::index_t kBlockSize =
+        ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
+    constexpr ck_tile::index_t kBlockPerCu = 1;
+
+    constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
+    ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
+
+    std::cout << "grid size = " << kGridSize << std::endl;
+    std::cout << "Total elements = " << total_elements << std::endl;
+
+    auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()));
+    auto input_size = ck_tile::make_tuple(M, N);
+
+    // Check whether the kernel configuration is supported.
+    if(!Kernel::IsSupportedArgument(input_size))
+    {
+        throw std::runtime_error(
+            "The kernel configuration is not supported for the given input size.");
+    }
+
+    // 4. Run the kernel.
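+    // Note on timing (inferred from the "cold iter"/"hot iter" argument descriptions above):
+    // the stream_config passed below runs 'warmup' untimed iterations first and then reports the
+    // average over 'repeat' timed iterations, so ave_time excludes one-time startup costs.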
+    float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
+                                   ck_tile::make_kernel(
+                                       Kernel{},
+                                       kGridSize,
+                                       kBlockSize,
+                                       0,
+                                       input_size,
+                                       ck_tile::make_tuple(N, 1), // input strides
+                                       ck_tile::make_tuple(N, 1), // output strides
+                                       input_tensors,
+                                       static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
+
+    std::cout << "Average time: " << ave_time << " ms" << std::endl;
+
+    // 5. Verify the output.
+    bool pass = true;
+    if(do_validation)
+    {
+        y_buf.FromDevice(y_validation.data());
+
+        auto op = [](const auto& v0) { return v0 * v0; };
+
+        ck_tile::reference_unary_elementwise(x_host_a, y_host, op);
+
+        pass = ck_tile::check_err(
+            y_validation, y_host, "Elementwise Square Error: Incorrect results!", 0.01, 0.01);
+    }
+
+    return pass;
+}
+
+int main(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    const std::string data_type = arg_parser.get_str("prec");
+    if(data_type == "fp16")
+    {
+        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+    }
+
+    return -3;
+}
diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp
index 1eb0445c84..1f0f0b9bc1 100644
--- a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp
+++ b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp
@@ -2,41 +2,93 @@
 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "batched_transpose_example.hpp"
-template
+namespace {
+
+template <ck_tile::index_t PipelineId>
+struct kernel_traits;
+
+template <>
+struct kernel_traits<0>
+{
+    template <typename InputType, typename BlockTile, typename WarpTile, bool kPadM, bool kPadN>
+    using Problem =
+        ck_tile::BatchedTransposeProblem<InputType, BlockTile, WarpTile, kPadM, kPadN>;
+    using Policy = ck_tile::BatchedTransposePolicy;
+    template <typename InputType, typename BlockTile, typename WarpTile, bool kPadM, bool kPadN>
+    using Pipeline =
+        ck_tile::BatchedTransposePipeline<Problem<InputType, BlockTile, WarpTile, kPadM, kPadN>,
+                                          Policy>;
+};
+
+template <>
+struct kernel_traits<1>
+{
+    template <typename InputType, typename BlockTile, typename WarpTile, bool kPadM, bool kPadN>
+    using Problem =
+        ck_tile::BatchedTransposeLdsProblem<InputType, BlockTile, WarpTile, kPadM, kPadN>;
+    using Policy = ck_tile::BatchedTransposeLdsPolicy;
+    template <typename InputType, typename BlockTile, typename WarpTile, bool kPadM, bool kPadN>
+    using Pipeline = ck_tile::BatchedTransposeLdsPipeline<
+        Problem<InputType, BlockTile, WarpTile, kPadM, kPadN>,
+        Policy>;
+};
+} // namespace
+
+template <typename InputType_,
+          ck_tile::index_t BlockX_,
+          ck_tile::index_t BlockY_,
+          ck_tile::index_t NumWarpsX_,
+          ck_tile::index_t NumWarpsY_,
+          bool PadM_,
+          bool PadN_,
+          ck_tile::index_t PipelineId_>
+struct BatchedTransposeConfig
+{
+    using InputType = InputType_;
+    static constexpr ck_tile::index_t kBlockX = BlockX_;
+    static constexpr ck_tile::index_t kBlockY = BlockY_;
+    static constexpr ck_tile::index_t kNumWarpsX = NumWarpsX_;
+    static constexpr ck_tile::index_t kNumWarpsY = NumWarpsY_;
+    static constexpr bool kPadM = PadM_;
+    static constexpr bool kPadN = PadN_;
+    static constexpr ck_tile::index_t kPipelineId = PipelineId_;
+};
+
+template <typename Config>
 float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
 {
     uint32_t dim_stride = a.height * a.width;
 
     a.dim_stride = dim_stride;
-    a.dim_block_h = block_y;
-    a.dim_block_w = block_x;
+    a.dim_block_h = Config::kBlockY;
+    a.dim_block_w = Config::kBlockX;
 
-    using block_tile = ck_tile::sequence;
-    using warp_tile = ck_tile::sequence;
-    using thread_tile = ck_tile::sequence;
-
-    using ts_problem =
-        ck_tile::BatchedTransposeProblem;
-    using ts_pipeline = ck_tile::BatchedTransposePipeline;
-
-    using kernel = ck_tile::BatchedTransposeKernel;
+    // TODO: this is fragile and slow to compile
+    using kernel = ck_tile::BatchedTransposeKernel<
+        typename kernel_traits<Config::kPipelineId>::template Pipeline<
+            typename Config::InputType,
+            ck_tile::sequence<Config::kBlockX, Config::kBlockY>,
+            ck_tile::sequence<Config::kNumWarpsX, Config::kNumWarpsY>,
+            Config::kPadM,
+            Config::kPadN>>;
 
     auto kargs = kernel::MakeKargs(a);
 
     const dim3 grids = kernel::GridSize(a);
     constexpr dim3 blocks = kernel::BlockSize();
 
-    printf("Grid: %u %u %u\n", grids.x, grids.y, grids.z);
-    printf("Block: %u %u %u\n", blocks.x, blocks.y, blocks.z);
-    printf("kargs: kargs.batch %d kargs.height %d kargs.width %d kargs.dim_strid %d\n",
+    printf("Pipeline: %d\n", Config::kPipelineId);
+    printf("Grid: x=%u y=%u z=%u\n", grids.x, grids.y, grids.z);
+    printf("Block: x=%u y=%u z=%u\n", blocks.x, blocks.y, blocks.z);
+    printf(
+        "Host args: batch=%d, height=%d, width=%d, dim_stride=%d, dim_block_h=%d, dim_block_w=%d\n",
+        a.batch,
+        a.height,
+        a.width,
+        a.dim_stride,
+        a.dim_block_h,
+        a.dim_block_w);
+    printf("kargs: kargs.batch=%d kargs.height=%d kargs.width=%d kargs.dim_stride=%d\n",
        kargs.batch,
        kargs.height,
        kargs.width,
@@ -52,22 +104,29 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
     return ave_time;
 }
 
-// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y
-#define FOREACH_TRANSPOSE_PARAM(F)                               \
-    F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, true, true)     \
-    F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, false, false)   \
-    F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, true, true)   \
-    F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, false, false) \
-    F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, true, true)   \
-    F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, false, false)
+// Param combo: short name, real type, block_x & y, warp-count x & y, pad M & N, pipeline id
+#define FOREACH_TRANSPOSE_PARAM(F)                          \
+    F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, true, true, 0)     \
+    F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, false, false, 0)   \
+    F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, true, true, 0)   \
+    F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, false, false, 0) \
+    F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, true, true, 0)   \
+    F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, false, false, 0) \
+    F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, true, true, 1)     \
+    F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, false, false, 1)   \
+    F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, true, true, 1)   \
+    F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, false, false, 1) \
+    F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, true, true, 1)   \
+    F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, false, false, 1)
 
 // Macro that defines one static function per line
-#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN)        \
-    static float                                                                           \
-    transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY##_##PADM##_##PADN( \
-        batched_transpose_kargs& a, ck_tile::stream_config& s)                             \
-    {                                                                                      \
-        return batched_transpose_dispatch(a, s);                                           \
+#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, PADM, PADN, PIPE)          \
+    static float                                                                           \
+        transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##PADM##_##PADN##_v##PIPE( \
+            batched_transpose_kargs& a, ck_tile::stream_config& s)                         \
+    {                                                                                      \
+        return batched_transpose_dispatch<                                                 \
+            BatchedTransposeConfig<REAL_TYPE, BX, BY, WX, WY, PADM, PADN, PIPE>>(a, s);    \
    }
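 
+// For reference, the first row of the parameter table above expands (modulo whitespace and
+// macro bookkeeping) to:
+//
+//   static float transpose_fn_fp8_64_64_1_1_true_true_v0(batched_transpose_kargs& a,
+//                                                        ck_tile::stream_config& s)
+//   {
+//       return batched_transpose_dispatch<
+//           BatchedTransposeConfig<ck_tile::fp8_t, 64, 64, 1, 1, true, true, 0>>(a, s);
+//   }
+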
 FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN)
@@ -76,38 +135,78 @@ FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN)
 float batched_transpose(batched_transpose_trait t,
                         batched_transpose_kargs a,
                         ck_tile::stream_config s)
 {
-    if(t.type == "fp8")
+    if(t.pipeline == "0")
     {
-        if(a.height % 64 == 0 && a.width % 64 == 0)
+        if(t.type == "fp8")
         {
-            return transpose_fn_fp8_64_64_64_64_8_8_false_false(a, s);
+            if(a.height % 64 == 0 && a.width % 64 == 0)
+            {
+                return transpose_fn_fp8_64_64_1_1_false_false_v0(a, s);
+            }
+            else
+            {
+                return transpose_fn_fp8_64_64_1_1_true_true_v0(a, s);
+            }
         }
-        else
+        else if(t.type == "fp16")
        {
-            return transpose_fn_fp8_64_64_64_64_8_8_true_true(a, s);
+            if(a.height % 64 == 0 && a.width % 64 == 0)
+            {
+                return transpose_fn_fp16_64_64_1_1_false_false_v0(a, s);
+            }
+            else
+            {
+                return transpose_fn_fp16_64_64_1_1_true_true_v0(a, s);
+            }
+        }
+        else if(t.type == "bf16")
+        {
+            if(a.height % 64 == 0 && a.width % 64 == 0)
+            {
+                return transpose_fn_bf16_64_64_1_1_false_false_v0(a, s);
+            }
+            else
+            {
+                return transpose_fn_bf16_64_64_1_1_true_true_v0(a, s);
+            }
         }
     }
-    else if(t.type == "fp16")
+    else if(t.pipeline == "1")
     {
-        if(a.height % 64 == 0 && a.width % 64 == 0)
+        if(t.type == "fp8")
         {
-            return transpose_fn_fp16_64_64_64_64_8_8_false_false(a, s);
+            if(a.height % 64 == 0 && a.width % 64 == 0)
+            {
+                return transpose_fn_fp8_64_64_1_1_false_false_v1(a, s);
+            }
+            else
+            {
+                return transpose_fn_fp8_64_64_1_1_true_true_v1(a, s);
+            }
         }
-        else
+        else if(t.type == "fp16")
         {
-            return transpose_fn_fp16_64_64_64_64_8_8_true_true(a, s);
-        }
-    }
-    else if(t.type == "bf16")
-    {
-        if(a.height % 64 == 0 && a.width % 64 == 0)
-        {
-            return transpose_fn_bf16_64_64_64_64_8_8_false_false(a, s);
-        }
-        else
-        {
-            return transpose_fn_bf16_64_64_64_64_8_8_true_true(a, s);
+            if(a.height % 64 == 0 && a.width % 64 == 0)
+            {
+                return transpose_fn_fp16_64_64_1_1_false_false_v1(a, s);
+            }
+            else
+            {
+                return transpose_fn_fp16_64_64_1_1_true_true_v1(a, s);
+            }
+        }
+        else if(t.type == "bf16")
+        {
+            if(a.height % 64 == 0 && a.width % 64 == 0)
+            {
+                return transpose_fn_bf16_64_64_1_1_false_false_v1(a, s);
+            }
+            else
+            {
+                return transpose_fn_bf16_64_64_1_1_true_true_v1(a, s);
+            }
         }
     }
+    return -1;
 }
diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp
index 33b6f0eacf..571386694b 100644
--- a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp
+++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp
@@ -102,7 +102,8 @@ auto create_args(int argc, char* argv[])
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
         .insert("seed", "-1", "seed to be used, -1 means random every time")
-        .insert("kname", "0", "t to 1 will print kernel name");
+        .insert("kname", "0", "set to 1 will print kernel name")
+        .insert("pipeline", "0", "0: no LDS usage, 1: LDS-accelerated (gfx950)");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
@@ -121,6 +122,7 @@ bool run_batched_transpose(ck_tile::ArgParser args)
     int n_repeat = args.get_int("repeat");
     std::string layout_in = args.get_str("layout_in");
     std::string layout_out = args.get_str("layout_out");
+    std::string pipeline = args.get_str("pipeline");
     int seed = args.get_int("seed");
 
     int dim_in[4], dim_out[4];
@@ -166,7 +168,7 @@ bool run_batched_transpose(ck_tile::ArgParser args)
 
     x_dev.ToDevice(x_host.data());
 
-    auto trait = batched_transpose_trait{prec, layout_in};
+    auto trait = batched_transpose_trait{prec, layout_in, pipeline};
 
     uint32_t height = nchw2nhwc ? C : H * W;
     uint32_t width = nchw2nhwc ?
H * W : C; @@ -185,17 +187,15 @@ bool run_batched_transpose(ck_tile::ArgParser args) auto ms = batched_transpose(trait, karg, sc); - std::size_t num_operations = N * C * H * (W - 1); - std::size_t num_bytes = N * C * H * W * sizeof(Type); + std::size_t num_bytes = N * C * H * W * sizeof(Type) * 2; // read + written - float ave_time = ms * 1E-3; float gb_per_sec = num_bytes / ms * 1.E-6; - float tflops = static_cast(num_operations) / ms * 1.E-6; std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out - << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops" - << gb_per_sec << " GB/s, " << std::endl; + << " : " << std::endl + << ms << " ms " << std::endl + << gb_per_sec << " GB/s " << std::endl; printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n", prec.c_str(), diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp index 487ddc17b2..c37dbed4b3 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp @@ -14,6 +14,7 @@ struct batched_transpose_trait { std::string type; std::string layout; + std::string pipeline; }; struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs diff --git a/example/ck_tile/35_batched_transpose/script/perf_test.sh b/example/ck_tile/35_batched_transpose/script/perf_test.sh index dde646eb2a..f19242af28 100755 --- a/example/ck_tile/35_batched_transpose/script/perf_test.sh +++ b/example/ck_tile/35_batched_transpose/script/perf_test.sh @@ -5,10 +5,14 @@ EXE=./build/bin/tile_example_batched_transpose +for C in "64" "256" "1024" "4096" "16384"; do +for W in "64" "256" "1024" "4096" "16384"; do for pr in "fp8" "fp16" "bf16"; do -$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=4096 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' +for pipeline in "0" "1"; do + +$EXE -pipeline=$pipeline -pr=$pr -N=1 -C=$C -H=1 -W=$W -layout_in='NCHW' -layout_out='NHWC' done +done +done +done \ No newline at end of file diff --git a/example/ck_tile/35_batched_transpose/script/smoke_test.sh b/example/ck_tile/35_batched_transpose/script/smoke_test.sh index 5ba2743364..a8bd692183 100755 --- a/example/ck_tile/35_batched_transpose/script/smoke_test.sh +++ b/example/ck_tile/35_batched_transpose/script/smoke_test.sh @@ -6,25 +6,27 @@ EXE=./build/bin/tile_example_batched_transpose for pr in "fp8" "fp16" "bf16"; do -$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=16 -C=64 
-H=32 -W=128 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=16 -C=64 -H=128 -W=32 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW' +for pipeline in "0" "1"; do +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=128 -C=1024 -H=64 -W=64 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=128 -C=1024 -H=64 -W=64 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=16 -C=64 -H=32 -W=128 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=16 -C=64 -H=128 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=2048 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=2048 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1024 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1024 -W=1024 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=8 -C=16 -H=8 -W=16 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW' done +done diff --git a/example/ck_tile/36_copy/CMakeLists.txt b/example/ck_tile/36_copy/CMakeLists.txt deleted file mode 100644 index d1b9ba923c..0000000000 --- a/example/ck_tile/36_copy/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -add_executable(test_copy_kernel EXCLUDE_FROM_ALL test_copy.cpp) -target_compile_options(test_copy_kernel PRIVATE - -mllvm -enable-noalias-to-md-conversion=0 -) \ No newline at end of file diff --git a/example/ck_tile/36_copy/test_copy.cpp b/example/ck_tile/36_copy/test_copy.cpp deleted file mode 100644 index 4123408453..0000000000 --- a/example/ck_tile/36_copy/test_copy.cpp +++ /dev/null @@ -1,118 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck_tile/host.hpp" -#include -#include "test_copy.hpp" - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "64", "m dimension") - .insert("n", "8", "n dimension") - .insert("id", "0", "warp to use") - .insert("v", "1", "cpu validation or not") - .insert("prec", "fp16", "precision") - .insert("warmup", "50", "cold iter") - .insert("repeat", "100", "hot iter"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -template -bool run(const ck_tile::ArgParser& arg_parser) -{ - using XDataType = DataType; - using YDataType = DataType; - - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - ck_tile::index_t warp_id = arg_parser.get_int("id"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - - ck_tile::HostTensor x_host({m, n}); - ck_tile::HostTensor y_host_ref({m, n}); - ck_tile::HostTensor y_host_dev({m, n}); - - // ck_tile::FillConstant{1.f}(x_host); - ck_tile::half_t value = 1; - for(int i = 0; i < m; i++) - { - value = 1; - for(int j = 0; j < n; j++) - { - x_host(i, j) = value++; - } - } - - ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); - - x_buf.ToDevice(x_host.data()); - - using BlockWaves = ck_tile::sequence<2, 1>; - using BlockTile = ck_tile::sequence<64, 8>; - using WaveTile = ck_tile::sequence<64, 8>; - using Vector = ck_tile::sequence<1, 2>; - constexpr bool AsyncCopy = true; - - ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); - std::cout << "grid size " << kGridSize << std::endl; - - using Shape = ck_tile::TileCopyShape; - using Problem = ck_tile::TileCopyProblem; - using Kernel = ck_tile::TileCopy; - - constexpr ck_tile::index_t kBlockSize = 128; - constexpr ck_tile::index_t kBlockPerCu = 1; - std::cout << "block size " << kBlockSize << std::endl; - std::cout << "warp SIze " << ck_tile::get_warp_size() << std::endl; - std::cout << "warps per block _M " << Shape::WarpPerBlock_M << " " << Shape::WarpPerBlock_N - << std::endl; - std::cout << "Block waves: " << BlockWaves::at(ck_tile::number<0>{}) << " " - << BlockWaves::at(ck_tile::number<1>{}) << std::endl; - std::cout << " Wave Groups: " << Shape::WaveGroups << std::endl; - - float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(x_buf.GetDeviceBuffer()), - static_cast(y_buf.GetDeviceBuffer()), - m, - n, - warp_id)); - - std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; - - bool pass = true; - - if(do_validation) - { - // reference - y_buf.FromDevice(y_host_dev.mData.data()); - pass = ck_tile::check_err(y_host_dev, x_host); - - std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; - } - - return pass; -} - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - const std::string data_type = arg_parser.get_str("prec"); - return run(arg_parser) ? 
0 : -2; -} diff --git a/example/ck_tile/37_transpose/CMakeLists.txt b/example/ck_tile/37_transpose/CMakeLists.txt deleted file mode 100644 index d6f374a9b4..0000000000 --- a/example/ck_tile/37_transpose/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -set(TARGET_NAME tile_example_transpose) -add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL transpose_example.cpp transpose_api.cpp) -target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) - -# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) -# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) -target_compile_options(tile_example_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) - diff --git a/example/ck_tile/37_transpose/README.md b/example/ck_tile/37_transpose/README.md deleted file mode 100644 index 21578dd00e..0000000000 --- a/example/ck_tile/37_transpose/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Batched Transpose -This folder contains example for transpose load for architecture gfx950. This transpose load has some constraints in input tile distribution. - -## build -``` -# in the root of ck_tile -mkdir build && cd build -# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ -# Make the transpose executable -make tile_example_transpose -j -``` -This will result in an executable `build/bin/tile_example_transpose` - -## example -``` -args: - -N input batch size (default:2) - -C input channel size. (default:64) - -H input height size. (default:1) - -W input width size. (default:64) - -v whether do CPU validation or not (default: 1) - -layout_in input tensor data layout - NCHW by default - -layout_out output tensor data layout - NHWC by default - -seed seed to be used, -1 means random every time (default:-1) - -k_name t to 1 will print kernel name (default:0) -``` \ No newline at end of file diff --git a/example/ck_tile/37_transpose/batched_transpose_kernel.hpp b/example/ck_tile/37_transpose/batched_transpose_kernel.hpp deleted file mode 100644 index 4681a12cf7..0000000000 --- a/example/ck_tile/37_transpose/batched_transpose_kernel.hpp +++ /dev/null @@ -1,120 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/elementwise.hpp" -#include "ck_tile/host/hip_check_error.hpp" -#include -#include - -namespace ck_tile { - -struct BatchedTransposeHostArgs -{ - const void* p_input; - void* p_output; - index_t batch; - index_t height; - index_t width; - // index_t dim_blocks; - index_t dim_stride; - index_t dim_block_h; - index_t dim_block_w; -}; - -template -struct BatchedTransposeKernel -{ - using Pipeline = remove_cvref_t; - using Problem = remove_cvref_t; - - using Type = typename Problem::DataType; - - struct BatchedTransposeKargs - { - const void* p_input; - void* p_output; - index_t batch; - index_t height; - index_t width; - index_t dim_stride; - }; - - using Kargs = BatchedTransposeKargs; - using Hargs = BatchedTransposeHostArgs; - - CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) - { - size_t grid_size_x = h.dim_block_w; - size_t grid_size_y = h.dim_block_h; - size_t grid_size_z = h.batch; - return dim3(grid_size_x, grid_size_y, grid_size_z); - } - - CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) - { - Kargs k; - k.p_input = h.p_input; - k.p_output = h.p_output; - k.batch = h.batch; - k.height = h.height; - k.width = h.width; - k.dim_stride = h.dim_stride; - return k; - } - - CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; } - - CK_TILE_DEVICE void operator()(Kargs kargs) const - { - __shared__ char smem[Pipeline::GetSmemSize()]; - static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondSizePerBlock; - static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadSizePerBlock; - - const auto iDim = blockIdx.z; - const auto x_m_n = [&]() { - const auto x_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_input) + iDim * kargs.dim_stride, - make_tuple(kargs.height, kargs.width), - make_tuple(kargs.width, 1), - number{}, - number<1>{}); - - return pad_tensor_view(x_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock); - const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock); - - const auto y_n_m = [&]() { - const auto y_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_output) + iDim * kargs.dim_stride, - make_tuple(kargs.width, kargs.height), - make_tuple(kargs.height, 1), - number{}, - number<1>{}); - - return pad_tensor_view(y_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - auto x_block_window = make_tile_window( - x_m_n, - make_tuple(number{}, number{}), - {static_cast(iM), static_cast(iN)}); - - auto y_block_window = make_tile_window( - y_n_m, - make_tuple(number{}, number{}), - {static_cast(iN), static_cast(iM)}); - - Pipeline{}(x_block_window, y_block_window, smem); - } -}; -} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/block_transpose.hpp b/example/ck_tile/37_transpose/block_transpose.hpp deleted file mode 100644 index 5c0baab846..0000000000 --- a/example/ck_tile/37_transpose/block_transpose.hpp +++ /dev/null @@ -1,149 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/core.hpp" -#include "transpose_policy.hpp" - -namespace ck_tile { - -template -struct TransposeTraits -{ - static constexpr index_t kLeadDim = kCol; - static constexpr index_t kSecondDim = kRow; -}; - -template -struct TransposeTraits -{ - static constexpr index_t kLeadDim = kRow; - static constexpr index_t kSecondDim = kCol; -}; - -// supports 2D transpose which will store to lds, then use ds_read_b*_tr_b* instruction to get the -// transposed data; Layout in TransposePipelineProblem is the original layout of the data in the -// global memory -template // col number per xdl ops -struct TransposePipelineProblem -{ - static_assert(kRowWarps_ * kColWarps_ * get_warp_size() == kBlockSize_, - "the block size is not correct!"); - using DataType = remove_cvref_t; - using Layout = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kLeadNumWarps = - TransposeTraits::kLeadDim; - static constexpr index_t kSecondNumWarps = - TransposeTraits::kSecondDim; - static constexpr index_t kLeadSizePerBlock = - TransposeTraits::kLeadDim; - static constexpr index_t kSecondSizePerBlock = - TransposeTraits::kSecondDim; - static constexpr index_t kLeadSizePerXdl = - TransposeTraits::kLeadDim; - static constexpr index_t kSecondSizePerXdl = - TransposeTraits::kSecondDim; - - static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits::kleadDim; - static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits::ksecondDim; - - static_assert(kLeadSizePerBlock % kLeadNumWarps == 0, - "block dim should be divided by warp dim!"); - static_assert(kSecondSizePerBlock % kSecondNumWarps == 0, - "block dim should be divided by warp dim!"); - // how many rows/cols implemented in one warp - static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps; - static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps; - - static_assert(kLeadSizePerWarp % kLeadSizePerXdl == 0, - "warp dim should be divided by xdl dim!"); - static_assert(kSecondSizePerWarp % kSecondSizePerXdl == 0, - "warp dim should be divided by xdl dim!"); - - // warp rows/cols is divided into xdl. - static constexpr index_t kLeadXdlNumPerWarp = kLeadSizePerWarp / kLeadSizePerXdl; - static constexpr index_t kSecondXdlNumPerWarp = kSecondSizePerWarp / kSecondSizePerXdl; - - static_assert(kLeadSizePerXdl % kQuadrantLeadDim == 0, - "xdl dim should be divided by quad dim!"); - static_assert(kSecondSizePerXdl % kQuadrantSecondDim == 0, - "xdl dim should be divided by quad dim!"); - // xdl rows/cols is divided into quadrants. 
- static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerXdl / kQuadrantLeadDim; - static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerXdl / kQuadrantSecondDim; - - static constexpr index_t kIterationsInSecondDim = - kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size(); -}; - -template -struct BlockTranspose -{ - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - - using DataType = remove_cvref_t; - using Layout = remove_cvref_t; - - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock; - static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock; - - static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - template - CK_TILE_DEVICE void operator()(const InputTileWindow& input_window, - OutputTileWindow& output_window, - void* __restrict__ p_smem) - { - auto input_tile_window = - make_tile_window(input_window, Policy::template MakeInputDistribution()); - auto output_tile_window = - make_tile_window(output_window, Policy::template MakeOutputDistribution()); - - DataType* p_lds_ptr = static_cast(p_smem); - constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor(); - auto input_lds_block = - make_tensor_view(p_lds_ptr, in_lds_block_desc); - - constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor(); - auto output_lds_block = - make_tensor_view(p_lds_ptr, out_lds_block_desc); - - auto copy_to_lds_window = - make_tile_window(input_lds_block, - make_tuple(number{}, number{}), - {0, 0}); - auto load_from_lds_window = - make_tile_window(output_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeLdsLoadTileDistribution()); - - auto x = load_tile(input_tile_window); - - store_tile(copy_to_lds_window, x); - block_sync_lds(); - - auto y = load_tile_transpose(load_from_lds_window); - - store_tile(output_tile_window, y); - } -}; - -} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/transpose_api.cpp b/example/ck_tile/37_transpose/transpose_api.cpp deleted file mode 100644 index fe184b4023..0000000000 --- a/example/ck_tile/37_transpose/transpose_api.cpp +++ /dev/null @@ -1,59 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
-#include "transpose_example.hpp" -#include - -template -float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) -{ - uint32_t dim_block_h = (a.height + block_y - 1) / block_y; - uint32_t dim_block_w = (a.width + block_x - 1) / block_x; - uint32_t dim_stride = a.height * a.width; - - a.dim_stride = dim_stride; - a.dim_block_h = dim_block_h; - a.dim_block_w = dim_block_w; - - using ts_problem = ck_tile::TransposePipelineProblem; - using ts_pipeline = ck_tile::BlockTranspose; - - using kernel = ck_tile::BatchedTransposeKernel; - - auto kargs = kernel::MakeKargs(a); - - const dim3 grids = kernel::GridSize(a); - constexpr dim3 blocks = kernel::BlockSize(); - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); - - return ave_time; -} - -float batched_transpose(batched_transpose_trait t, - batched_transpose_kargs a, - ck_tile::stream_config s) -{ - if(t.type == "fp16") - { - return batched_transpose_dispatch(a, s); - } - else if(t.type == "fp8") - { - return batched_transpose_dispatch(a, s); - } - - return -1; -} diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt new file mode 100644 index 0000000000..bdcb6f50bd --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -0,0 +1,13 @@ +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) + +if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") + add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp) + target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping ck_tile quant gemm tests for current target") +endif() diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md new file mode 100644 index 0000000000..742a88dee7 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -0,0 +1,35 @@ +# GEMM Matrix Multiplication + +This folder contains example for Block Scale GEMM using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# The aquant pipeline method on the gemm calculation +make tile_example_gemm_aquant_basic -j +``` +This will result in an executable `build/bin/tile_example_gemm_aquant_basic` + +## example +``` +args: + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) + -k k dimension (default:64) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. 
fp16/bf16/fp8/bf8/int8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) +``` diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp new file mode 100644 index 0000000000..2667cae788 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "gemm_utils.hpp" + +template +float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) +{ + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + + static_assert(std::is_same_v); + + constexpr ck_tile::index_t M_Tile = 16; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 256; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 32; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmAQuantTraits; + + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + + using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + constexpr bool transposed_warp_gemm = false; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using CodegenPipelineProblem = + ck_tile::GemmAQuantPipelineProblem; + using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + transposed_warp_gemm, + ck_tile::memory_operation_enum::set>>; + using Kernel = + ck_tile::AQuantGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(args.k_batch != 1) + { + throw std::runtime_error("split-k is not supported yet!"); + } + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +#include "run_gemm_aquant_example.inc" + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + else + { + throw std::runtime_error("Unsupported data type for A."); + } + + return 0; +} + +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(data_type == "fp8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "bf8") + { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4fp8") + { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4bf8") + { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4f32fp8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4f32bf8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } +} + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp new file mode 100644 index 0000000000..35e80ddb89 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -0,0 +1,675 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm_group_quant.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 +#define CK_TILE_PIPELINE_PRESHUFFLE 5 + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(__gfx950__) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(__gfx950__) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
+
+struct GemmConfigBase
+{
+    static constexpr bool kPadM = false;
+    static constexpr bool kPadN = false;
+    static constexpr bool kPadK = false;
+
+    static constexpr bool PermuteA = false;
+    static constexpr bool PermuteB = false;
+
+    static constexpr bool TransposeC = false;
+    static constexpr bool UseStructuredSparsity = false;
+
+    static constexpr int kBlockPerCu = 1;
+    static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
+    static constexpr ck_tile::index_t TileParitionerM01 = 4;
+    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
+    static constexpr ck_tile::index_t NumWaveGroups = 1;
+    static constexpr bool Preshuffle = false;
+};
+
+template <typename PrecType>
+struct GemmConfigMemoryInterwave : public GemmConfigBase
+{
+    // Memory-friendly configuration for the Interwave scheduler
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 32;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 4;
+    static constexpr ck_tile::index_t N_Warp = 1;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
+
+    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
+    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
+};
+
+template <typename PrecType>
+struct GemmConfigMemoryIntrawave : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 32;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 4;
+    static constexpr ck_tile::index_t N_Warp = 1;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
+
+    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV3 : public GemmConfigBase
+{
+    // Compute V3 only supports the Intrawave scheduler
+    static constexpr ck_tile::index_t M_Tile = 32;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 256;
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<M_Warp_Tile, PrecType>();
+
+    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV3_1 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 256;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<M_Warp_Tile, PrecType>();
+
+    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV3_2 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<M_Warp_Tile, PrecType>();
+
+    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
+
+    static constexpr int kBlockPerCu = 2;
+};
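+
+// Illustrative only: an editorial sketch of how a new tile configuration is
+// assembled by inheriting GemmConfigBase and overriding just the fields that
+// differ. The struct name and tile sizes are hypothetical and nothing in the
+// example instantiates it.
+template <typename PrecType>
+struct GemmConfigComputeV3_Custom : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 64;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<M_Warp_Tile, PrecType>();
+
+    // Everything else (padding, scheduler, NumWaveGroups, ...) keeps the
+    // GemmConfigBase defaults.
+    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
+};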
+
+template <typename PrecType>
+struct GemmConfigComputeV4 : public GemmConfigBase
+{
+    // Compute V4 only supports the Intrawave scheduler.
+    // It uses ping-pong (double-buffered) reads at the LDS level.
+    static constexpr ck_tile::index_t M_Tile = 256;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<M_Warp_Tile, PrecType>();
+
+    static constexpr bool DoubleSmemBuffer = true;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV4_1 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 256;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<M_Warp_Tile, PrecType>();
+
+    static constexpr bool DoubleSmemBuffer = true;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV5 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 1;
+    static constexpr ck_tile::index_t K_Warp = 2;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<M_Warp_Tile, PrecType>();
+
+    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
+    static constexpr ck_tile::index_t NumWaveGroups = 2;
+};
+
+template <typename PrecType>
+struct GemmConfigPreshufle_1 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<M_Warp_Tile, PrecType>();
+
+    static constexpr int kBlockPerCu = 2;
+    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
+    static constexpr bool Preshuffle = true;
+    static constexpr bool DoubleSmemBuffer = false;
+};
+
+template <typename PrecType>
+struct GemmConfigPreshufle_2 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<M_Warp_Tile, PrecType>();
+
+    static constexpr int kBlockPerCu = 2;
+    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
+    static constexpr bool Preshuffle = true;
+    static constexpr bool DoubleSmemBuffer = false;
+};
+
+template <typename ADataType, typename BDataType = ADataType>
+struct GemmTypeConfig;
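+
+// Usage sketch (illustrative, not part of the example flow): the type-config
+// traits are consumed by pulling out the member aliases, e.g.
+//
+//   using TypeConfig  = GemmTypeConfig<ck_tile::half_t>;
+//   using ADataType   = typename TypeConfig::ADataType;   // ck_tile::half_t
+//   using AccDataType = typename TypeConfig::AccDataType; // float
+//
+// Each specialization below fixes the accumulator and output precision that
+// the example pairs with a given input precision.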
+template <>
+struct GemmTypeConfig<ck_tile::half_t>
+{
+    using ADataType = ck_tile::half_t;
+    using BDataType = ck_tile::half_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+    // ToDo: Add more bias configs to support different categories of GEMM.
+};
+
+template <>
+struct GemmTypeConfig<ck_tile::bf16_t>
+{
+    using ADataType = ck_tile::bf16_t;
+    using BDataType = ck_tile::bf16_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::bf16_t;
+};
+
+template <>
+struct GemmTypeConfig<ck_tile::fp8_t>
+{
+    using ADataType = ck_tile::fp8_t;
+    using BDataType = ck_tile::fp8_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+};
+
+template <>
+struct GemmTypeConfig<ck_tile::bf8_t>
+{
+    using ADataType = ck_tile::bf8_t;
+    using BDataType = ck_tile::bf8_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+};
+
+template <>
+struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t>
+{
+    using ADataType = ck_tile::half_t;
+    using BDataType = ck_tile::pk_int4_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+};
+
+template <>
+struct GemmTypeConfig<ck_tile::int8_t>
+{
+    using ADataType = ck_tile::int8_t;
+    using BDataType = ck_tile::int8_t;
+    using AccDataType = int32_t;
+    using CDataType = int32_t;
+};
+
+template <typename ADataType_, typename QDataType_, typename BDataType_, typename CDataType_>
+struct GemmQuantTypeConfig
+{
+    using ADataType = ADataType_;
+    using QDataType = QDataType_;
+    using BDataType = BDataType_;
+    using AccDataType = float;
+    using CDataType = CDataType_;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::half_t, float, ck_tile::half_t, ck_tile::half_t>
+{
+    using ADataType = ck_tile::half_t;
+    using QDataType = float;
+    using BDataType = ck_tile::half_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::bf16_t, float, ck_tile::bf16_t, ck_tile::bf16_t>
+{
+    using ADataType = ck_tile::bf16_t;
+    using QDataType = float;
+    using BDataType = ck_tile::bf16_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::bf16_t;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::fp8_t, float, ck_tile::fp8_t, ck_tile::half_t>
+{
+    using ADataType = ck_tile::fp8_t;
+    using QDataType = float;
+    using BDataType = ck_tile::fp8_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::bf8_t, float, ck_tile::bf8_t, ck_tile::half_t>
+{
+    using ADataType = ck_tile::bf8_t;
+    using QDataType = float;
+    using BDataType = ck_tile::bf8_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::half_t, float, ck_tile::pk_int4_t, ck_tile::half_t>
+{
+    using ADataType = ck_tile::half_t;
+    using QDataType = float;
+    using BDataType = ck_tile::pk_int4_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::fp8_t, float, ck_tile::fp8_t, float>
+{
+    using ADataType = ck_tile::fp8_t;
+    using QDataType = float;
+    using BDataType = ck_tile::fp8_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::bf8_t, float, ck_tile::bf8_t, float>
+{
+    using ADataType = ck_tile::bf8_t;
+    using QDataType = float;
+    using BDataType = ck_tile::bf8_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, ck_tile::fp8_t, float>
+{
+    using ADataType = ck_tile::pk_int4_t;
+    using QDataType = ck_tile::fp8_t;
+    using BDataType = ck_tile::fp8_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::fp8_t, float>
+{
+    using ADataType = ck_tile::fp8_t;
+    using QDataType = ck_tile::fp8_t;
+    using BDataType = ck_tile::fp8_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bf8_t, float>
+{
+    using ADataType = ck_tile::bf8_t;
+    using QDataType = ck_tile::bf8_t;
+    using BDataType = ck_tile::bf8_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, ck_tile::bf8_t, float>
+{
+    using ADataType = ck_tile::pk_int4_t;
+    using QDataType = ck_tile::bf8_t;
+    using BDataType = ck_tile::bf8_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::pk_int4_t, float, ck_tile::fp8_t, float>
+{
+    using ADataType = ck_tile::pk_int4_t;
+    using QDataType = float;
+    using BDataType = ck_tile::fp8_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::pk_int4_t, float, ck_tile::bf8_t, float>
+{
+    using ADataType = ck_tile::pk_int4_t;
+    using QDataType = float;
+    using BDataType = ck_tile::bf8_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::pk_int4_t, float>
+{
+    using ADataType = ck_tile::fp8_t;
+    using QDataType = ck_tile::fp8_t;
+    using BDataType = ck_tile::pk_int4_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::pk_int4_t, float>
+{
+    using ADataType = ck_tile::bf8_t;
+    using QDataType = ck_tile::bf8_t;
+    using BDataType = ck_tile::pk_int4_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::fp8_t, float, ck_tile::pk_int4_t, float>
+{
+    using ADataType = ck_tile::fp8_t;
+    using QDataType = float;
+    using BDataType = ck_tile::pk_int4_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <>
+struct GemmQuantTypeConfig<ck_tile::bf8_t, float, ck_tile::pk_int4_t, float>
+{
+    using ADataType = ck_tile::bf8_t;
+    using QDataType = float;
+    using BDataType = ck_tile::pk_int4_t;
+    using AccDataType = float;
+    using CDataType = float;
+};
+
+template <typename DataType>
+struct DataTypeTraits;
+
+template <>
+struct DataTypeTraits<float>
+{
+    static constexpr const char* name = "fp32";
+};
+
+template <>
+struct DataTypeTraits<double>
+{
+    static constexpr const char* name = "fp64";
+};
+
+template <>
+struct DataTypeTraits<int32_t>
+{
+    static constexpr const char* name = "int32";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::half_t>
+{
+    static constexpr const char* name = "fp16";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::bf16_t>
+{
+    static constexpr const char* name = "bf16";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::fp8_t>
+{
+    static constexpr const char* name = "fp8";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::bf8_t>
+{
+    static constexpr const char* name = "bf8";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::pk_int4_t>
+{
+    static constexpr const char* name = "pk_int4_t";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::int8_t>
+{
+    static constexpr const char* name = "int8";
+};
+
+template <ck_tile::index_t PipelineId>
+struct PipelineTypeTraits;
+
+template <>
+struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
+{
+    template <typename PipelineProblem>
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
+    template <typename PipelineProblem>
+    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
+};
+
+template <>
+struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
+{
+    template <typename PipelineProblem>
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
+    template <typename PipelineProblem>
+    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
+};
+
+template <>
+struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
+{
+    template <typename PipelineProblem>
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
+    template <typename PipelineProblem>
+    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
+};
+
+template <>
+struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
+{
+    template <typename PipelineProblem>
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
+    template <typename PipelineProblem>
+    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
+};
+
+template <>
+struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
+{
+    template <typename PipelineProblem>
+    using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
+    template <typename PipelineProblem>
+    using UniversalGemmPipeline =
+        ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
+};
+
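+// Usage sketch (illustrative; "PipelineProblem" stands for whatever problem
+// type the caller instantiates): a GemmConfig's Pipeline constant selects the
+// pipeline implementation at compile time, e.g.
+//
+//   using Traits   = PipelineTypeTraits<GemmConfig::Pipeline>;
+//   using Pipeline = typename Traits::template GemmPipeline<PipelineProblem>;
+//
+// so switching a config from CK_TILE_PIPELINE_COMPUTE_V3 to
+// CK_TILE_PIPELINE_MEMORY swaps the whole hot loop without touching call sites.
+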
+ .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("aq_layout", "R", "Aq tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("persistent", "0", "0:non-persistent, 1:persistent") + .insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc new file mode 100644 index 0000000000..9bdef9755b --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include +#include + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& aq_m_aqk_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t AQK, + ck_tile::index_t stride_A, + ck_tile::index_t stride_AQ, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + ck_tile::AQuantGemmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.QK = AQK; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + args.stride_AQ = stride_AQ; + + float ave_time = gemm_calc_aquant( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK + + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B + << " StrideC =" << stride_C << " A_Layout =" << ALayout::name + << " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name + << " A_Type = " << DataTypeTraits::name + << " AQ_Type = " << DataTypeTraits::name + << " 
B_Type = " << DataTypeTraits::name + << " Acc_Type = " << DataTypeTraits::name + << " C_Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +template +int run_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const AQLayout aq_layout = AQLayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using ADataType = typename TypeConfig::ADataType; + using AQDataType = typename TypeConfig::QDataType; + using BDataType = typename TypeConfig::BDataType; + using AccDataType = typename TypeConfig::AccDataType; + using CDataType = typename TypeConfig::CDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + if(K % QuantGroupSize != 0) + { + throw std::runtime_error("K must be aligned with QuantGroupSize"); + } + + ck_tile::index_t AQK = K / QuantGroupSize; + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor aq_m_aqk( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution fill_seed(0, 500); + + if(init_method == 0) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else if(init_method == 1) + { + std::cout << "Monotonic initialization is not supported." 
<< std::endl; + return 0; + } + else if(init_method == 2) + { + ck_tile::FillConstant{static_cast(0x22)}(a_m_k); + ck_tile::FillConstant{static_cast(0.5f)}(aq_m_aqk); + ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + } + else + { + a_m_k.SetZero(); + aq_m_aqk.SetZero(); + b_k_n.SetZero(); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + invoke_gemm(a_m_k_dev_buf, + aq_m_aqk_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + AQK, + stride_A, + stride_AQ, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + if(!pass) + { + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + std::cout << "GPU verification is not implemented yet. 
Re-run with -v=1" << std::endl; + return false; + } + + return pass; +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 8989060842..630b96ede0 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -20,6 +20,6 @@ add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) add_subdirectory(19_gemm_multi_d) add_subdirectory(20_grouped_convolution) +add_subdirectory(21_elementwise) add_subdirectory(35_batched_transpose) -add_subdirectory(36_copy) -add_subdirectory(37_transpose) +add_subdirectory(38_block_scale_gemm) diff --git a/example/ck_tile/remod.py b/example/ck_tile/remod.py index fdc0dcf5d7..b64fac7b06 100644 --- a/example/ck_tile/remod.py +++ b/example/ck_tile/remod.py @@ -13,7 +13,7 @@ for p in sorted(Path("./").rglob("*")): # formatting for x in all_files: subprocess.Popen(f'dos2unix {str(x)}', shell=True) - cmd = f'clang-format-12 -style=file -i {str(x)}' + cmd = f'clang-format-18 -style=file -i {str(x)}' #for xp in x.parents: #print(get_file_base(x)) subprocess.Popen(cmd, shell=True) diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index 0dfd275269..e6e3402e64 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -12,9 +12,8 @@ inline void hip_check_error(hipError_t x) if(x != hipSuccess) { std::ostringstream ss; - ss << "HIP runtime error: " << hipGetErrorString(x) << ". " - << "hip_check_error.hpp" - << ": " << __LINE__ << "in function: " << __func__; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << "hip_check_error.hpp" << ": " + << __LINE__ << "in function: " << __func__; throw std::runtime_error(ss.str()); } } diff --git a/include/ck/library/utility/algorithm.hpp b/include/ck/library/utility/algorithm.hpp index 57136f8a2a..185a147cce 100644 --- a/include/ck/library/utility/algorithm.hpp +++ b/include/ck/library/utility/algorithm.hpp @@ -11,10 +11,10 @@ namespace ck { namespace ranges { template -auto copy(InputRange&& range, OutputIterator iter) - -> decltype(std::copy(std::begin(std::forward(range)), - std::end(std::forward(range)), - iter)) +auto copy(InputRange&& range, + OutputIterator iter) -> decltype(std::copy(std::begin(std::forward(range)), + std::end(std::forward(range)), + iter)) { return std::copy(std::begin(std::forward(range)), std::end(std::forward(range)), diff --git a/include/ck/library/utility/fill.hpp b/include/ck/library/utility/fill.hpp index 4f421b4282..05357b1637 100644 --- a/include/ck/library/utility/fill.hpp +++ b/include/ck/library/utility/fill.hpp @@ -138,9 +138,10 @@ struct FillConstant } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 33c918c997..fb8f6e79dc 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -202,7 +202,7 @@ struct joinable_thread : std::thread { } - joinable_thread(joinable_thread&&) = default; + joinable_thread(joinable_thread&&) = default; joinable_thread& operator=(joinable_thread&&) = default; ~joinable_thread() @@ -320,7 +320,7 @@ struct Tensor ~Tensor() 
= default; Tensor& operator=(const Tensor&) = default; - Tensor& operator=(Tensor&&) = default; + Tensor& operator=(Tensor&&) = default; template explicit Tensor(const Tensor& other) : Tensor(other.template CopyAsType()) diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index 3ffac32469..28974427d7 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -108,13 +108,13 @@ struct TensorAdaptor __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() { - constexpr auto all_low_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, - LowerDimensionHiddenIdss{}); + constexpr auto all_low_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); - constexpr auto all_up_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, - UpperDimensionHiddenIdss{}); + constexpr auto all_up_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); @@ -338,8 +338,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran]; // sequence in, sequence out - constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr - { + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr { auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); // shift hidden id so every dim id is unique @@ -361,8 +360,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); return low_dim_hidden_ids_1_mod_; - } - (); + }(); return generate_sequence_v2( [&](auto i) constexpr { return Number{}; }, @@ -384,8 +382,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran]; // sequence in, constexpr tuple out - constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr - { + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr { auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); // shift hidden id @@ -394,8 +391,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); return up_dim_hidden_ids_1_mod_; - } - (); + }(); // constexpr tuple to sequence return generate_sequence_v2( diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index f1df2eedd4..a82f69fb3f 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -365,7 +365,7 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus{}, Number<0>{})); constexpr auto up_dim_hidden_idss = generate_tuple( - [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr { return typename arithmetic_sequence_gen{}); // new visible dimension's hidden ids - constexpr auto unordered_new_visible_dim_hidden_ids = unpack( - [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + constexpr auto unordered_new_visible_dim_hidden_ids = + unpack([](auto... 
xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); - constexpr auto new_visible_dim_unordered2ordered = unpack( - [](auto... xs) constexpr { return merge_sequences(xs...); }, - NewUpperDimensionNewVisibleIdss{}); + constexpr auto new_visible_dim_unordered2ordered = + unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); constexpr auto new_visible_dim_hidden_ids = unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp index 9a326092d2..67da37cc90 100644 --- a/include/ck/tensor_description/tensor_space_filling_curve.hpp +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -94,10 +94,8 @@ struct SpaceFillingCurve // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the // idim-th element of multidimensional index. // All constexpr variables have to be captured by VALUE. - constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr - { - constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr - { + constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr { + constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr { auto res = idx_1d.value; auto id = 0; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp index c929956124..d0a594e2c6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -152,7 +152,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index 14856f210c..6fb62bc677 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -91,6 +91,78 @@ struct BlockwiseGemmWmmaops_pipeline_base true> c_thread_buf_; + struct Empty + { + __device__ Empty() {}; + template + __device__ void GlobalLoad(bool cond) + { + ignore = NBuffer; + ignore = cond; + } + }; + + template + struct BScale + { + __device__ BScale(GridDesc b_scale_grid_desc_, + ThreadCopy b_scale_thread_copy_, + GridBuffer b_scale_grid_buf_) + : b_scale_thread_copy(b_scale_thread_copy_), + b_scale_grid_desc(b_scale_grid_desc_), + b_scale_grid_buf(b_scale_grid_buf_) {}; + + static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{}); + static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block; + + static constexpr auto b_scale_thread_desc = BScaleThreadDesc{}; + + static constexpr auto b_scale_thread_copy_step = + make_tuple(make_multi_index(NWaves * NPerWmma, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK)); + + template + __device__ void GlobalLoad(bool cond) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + 
b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, Number<0>{}), + b_scale_thread_bufs(Number{})); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(cond) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + } + + ThreadCopy b_scale_thread_copy; + GridDesc b_scale_grid_desc; + GridBuffer b_scale_grid_buf; + StaticallyIndexedArray{}> b_scale_thread_bufs; + }; + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } __device__ static auto GetWaveIdx() @@ -285,7 +357,7 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, @@ -296,7 +368,7 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, B_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index df82e155be..f25648efa6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -132,6 +132,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename CThreadBuffer, + typename BScaleStruct> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -172,7 +175,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); @@ -186,6 +192,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -195,20 +203,42 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - b_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + if constexpr(ck::is_same::value == true) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -258,6 +288,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1((i + 2) % num_loop_per_scale 
== 0); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -378,6 +409,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename CThreadBuffer, + typename BScaleStruct> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -421,7 +455,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); @@ -435,6 +472,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -445,30 +484,57 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_offset) { static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, - I0, - I0, - I0, - I0, - I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0_inner, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - I0, - I0, - I0, - I0, - I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0_inner, I0, I0, I0), - b_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, + m0, + I0, + I0, + I0, + I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0_inner, I0, I0, I0), + a_thread_buf); + }); + if constexpr(ck::is_same::value == true) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, + n0, + I0, + I0, + I0, + I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0_inner, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, + n0, + I0, + I0, + I0, + I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs(I0)[Number< + n0 * BScaleStruct::num_scale_k_block + + (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}], + b_thread_desc_, + make_tuple(I0, n0, k0_inner, I0, I0, I0), + b_thread_buf); + }); + } }); __builtin_amdgcn_sched_barrier(0); @@ -564,6 +630,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1((i + 2) % num_loop_per_scale == 0); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -613,7 +680,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, @@ -624,7 +691,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, B_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 5ceb8a6be4..8fed23d151 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -132,6 +132,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3 + __device__ inline void LocalLoad(ABlockBuffer& a_block_buf, + AThreadBuffer& a_thread_buf, + BBlockBuffer& b_block_buf, + BThreadBuffer& b_thread_buf, + BScaleStruct& b_scale_struct) const + { + static_for<0, KRepeat, 1>{}([&](auto k0) { 
+ static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + + if constexpr(ck::is_same_v) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + }); + } + template + typename CThreadBuffer, + typename BScaleStruct> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -283,7 +338,10 @@ struct BlockwiseGemmWmmaops_pipeline_v3( @@ -298,6 +356,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop_per_scale == 1); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -314,20 +374,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - b_thread_buf); - }); + + LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct); __builtin_amdgcn_sched_barrier(0); @@ -348,6 +396,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3((i + 2) % num_loop_per_scale == 0); + static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -392,22 +442,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - b_thread_buf); - }); + LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp index 438d7d8ac3..231dbf817c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -96,9 +96,9 @@ template < index_t KPack, bool TransposeC = false, index_t AMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops, + KPack * XdlopsGemm{}.K0PerXdlops, index_t BMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops> + KPack * XdlopsGemm{}.K0PerXdlops> struct BlockwiseGemmXdlops_pipeline_v4 { static constexpr auto I0 = Number<0>{}; @@ -188,7 +188,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + 
CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); @@ -217,7 +217,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 template __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 9296b8136f..cd13dbb836 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -153,7 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_base template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); @@ -182,7 +182,7 @@ struct BlockwiseGemmXdlops_pipeline_base template __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp index 7e11304e2f..629bbb316f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp @@ -226,85 +226,197 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle 2) + { + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; - constexpr auto num_ds_read_a_mfma_perstage = - math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); - constexpr auto num_ds_read_a_prefetch_stages = 2; + constexpr auto num_ds_read_a_prefetch_stages = 2; - constexpr auto buffer_load_perstage_more = - math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); - constexpr auto buffer_load_perstage_less = - math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); - constexpr auto buffer_load_perstage_stage2 = - math::integer_divide_floor((num_buffer_load_stage2), 2); + constexpr auto buffer_load_perstage_more = + math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = + math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_stage2 = + math::integer_divide_floor((num_buffer_load_stage2), 2); - constexpr auto buffer_load_stages_more = - num_buffer_load_stage1 - - math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * - ((num_total_stages - 2)); + constexpr auto buffer_load_stages_more = + num_buffer_load_stage1 - + math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * + ((num_total_stages - 2)); - constexpr auto buffer_load_issue_point_interval_more = - num_mfma_perstage / buffer_load_perstage_more; 
- constexpr auto buffer_load_issue_point_interval_less = - num_mfma_perstage / buffer_load_perstage_less; - constexpr auto buffer_load_issue_point_interval_stage2 = - num_mfma_perstage / buffer_load_perstage_stage2; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto buffer_load_issue_point_interval_stage2 = + num_mfma_perstage / buffer_load_perstage_stage2; - // Stage 1 - // global read more - static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // Stage 1 + // global read more + static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // global read less + static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // Stage 2, Sync + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + else + { + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + + num_buffer_load_b_scale; + constexpr auto num_dsread_a_mfma = math::integer_divide_ceil( + num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) { + static_for<0, mfma_perstage_more, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto) 
{ + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < + mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) { __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } - }); - }); - - // global read less - static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + else { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read } }); - }); - - // Stage 2, Sync - // lds synchronization, prefetch next loop local A - static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); - }); + } } template __device__ 
static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); const auto waveId_m = wave_idx[I0]; @@ -138,7 +138,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 template __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); const auto waveId_m = wave_idx[I0]; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index d3f6344c27..e6bb2d8db3 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -114,7 +114,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); @@ -143,7 +143,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 template __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); @@ -667,9 +667,9 @@ template < index_t KPack, bool TransposeC = false, index_t AMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops, + KPack * XdlopsGemm{}.K0PerXdlops, index_t BMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops> + KPack * XdlopsGemm{}.K0PerXdlops> struct BlockwiseGemmXdlops_v2 { static constexpr auto I0 = Number<0>{}; @@ -742,7 +742,7 @@ struct BlockwiseGemmXdlops_v2 template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); @@ -771,7 +771,7 @@ struct BlockwiseGemmXdlops_v2 template __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index 287c6701c3..84ee096cba 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -90,7 +90,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp index 98cc149f4d..aa06f8c6c1 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -258,8 +258,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad src_buf.template DirectCopyToLds, ScalarPerVector>( dst_buf, src_offset, dst_offset, is_src_valid); - constexpr auto move_on_dim = [&]() constexpr - { + 
diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp
index 98cc149f4d..aa06f8c6c1 100644
--- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp
+++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp
@@ -258,8 +258,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
         src_buf.template DirectCopyToLds, ScalarPerVector>(
             dst_buf, src_offset, dst_offset, is_src_valid);

-        constexpr auto move_on_dim = [&]() constexpr
-        {
+        constexpr auto move_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;

             static_for<0, nDim, 1>{}([&](auto i) {
@@ -271,8 +270,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
             });

             return move_on_dim_;
-        }
-        ();
+        }();

         // Decide whether to move forward or backward.
         constexpr auto forward_sweep = [&]() {
diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp
index 3e9e501126..55dd924f8c 100644
--- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp
+++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp
@@ -281,8 +281,7 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
         src_buf.template DirectCopyToLds, ScalarPerVector>(
             dst_buf, src_offset, dst_offset, true);

-        constexpr auto move_src_on_dim = [&]() constexpr
-        {
+        constexpr auto move_src_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;

             static_for<0, nDim, 1>{}([&](auto i) {
@@ -295,11 +294,9 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
             });

             return move_on_dim_;
-        }
-        ();
+        }();

-        constexpr auto move_dst_on_dim = [&]() constexpr
-        {
+        constexpr auto move_dst_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;

             static_for<0, nDim, 1>{}([&](auto i) {
@@ -311,8 +308,7 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
             });

             return move_on_dim_;
-        }
-        ();
+        }();

         // Decide whether to move forward or backward.
         constexpr auto forward_sweep = [&]() {
diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp
index 9285211519..c946abb77d 100644
--- a/include/ck/tensor_operation/gpu/device/device_base.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_base.hpp
@@ -49,8 +49,8 @@ namespace device {
 #ifndef CK_CODE_GEN_RTC
 struct BaseArgument
 {
-    BaseArgument() = default;
-    BaseArgument(const BaseArgument&) = default;
+    BaseArgument()                    = default;
+    BaseArgument(const BaseArgument&) = default;
     BaseArgument& operator=(const BaseArgument&) = default;

     virtual ~BaseArgument() {}
@@ -60,8 +60,8 @@
 struct BaseInvoker
 {
-    BaseInvoker() = default;
-    BaseInvoker(const BaseInvoker&) = default;
+    BaseInvoker()                   = default;
+    BaseInvoker(const BaseInvoker&) = default;
     BaseInvoker& operator=(const BaseInvoker&) = default;

     virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
@@ -75,8 +75,8 @@
 struct BaseOperator
 {
-    BaseOperator() = default;
-    BaseOperator(const BaseOperator&) = default;
+    BaseOperator()                    = default;
+    BaseOperator(const BaseOperator&) = default;
     BaseOperator& operator=(const BaseOperator&) = default;
 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
     virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
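Editor's note: the two transfer headers above only reflow the "immediately invoked constexpr lambda" idiom, moving the call parentheses onto the closing-brace line. A self-contained illustration of the same idiom (example values are hypothetical):

#include <array>
#include <cstddef>

// Build a lookup table at compile time by invoking the lambda in place.
constexpr auto make_flags = [] {
    std::array<bool, 4> f{};
    for(std::size_t i = 0; i < f.size(); ++i)
        f[i] = (i % 2 == 0);
    return f;
}(); // invoked on the same line as the closing brace, as in the new style

static_assert(make_flags[0] && !make_flags[1], "even indices are set");

Both layouts compile identically; the hunks are pure formatting churn from a clang-format update.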
{" << str.str() << "}" - << "}" << std::endl; + std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SE:" << StrideE + << ", " << "SDs: {" << str.str() << "}" << "}" << std::endl; } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 72c011bfb2..1dd143f6a3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -205,25 +205,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( - AsPointer p_as_grid, - BsPointer p_bs_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index fc1a2b995a..c57d5316ba 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -36,25 +36,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_d_xdl_cshuffle( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - const index_t batch_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
index fc1a2b995a..c57d5316ba 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
@@ -36,25 +36,25 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_contraction_multiple_d_xdl_cshuffle(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            FloatDsPointer p_ds_grid,
-            FloatE* __restrict__ p_e_grid,
-            const index_t batch_count,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CDEElementwiseOperation cde_element_op,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                ds_grid_desc_mblock_mperblock_nblock_nperblock,
-            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                e_grid_desc_mblock_mperblock_nblock_nperblock,
-            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-            const Block2ETileMap block_2_etile_map)
+    kernel_contraction_multiple_d_xdl_cshuffle(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatDsPointer p_ds_grid,
+        FloatE* __restrict__ p_e_grid,
+        const index_t batch_count,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+        const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
index 0cd1d84a43..c82da32313 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
@@ -58,21 +58,21 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
-                                          const ABDataType* __restrict__ p_b_grid,
-                                          EDataType* __restrict__ p_e_grid,
-                                          const index_t batch_count,
-                                          const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-                                          const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-                                          const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-                                              e_grid_desc_mblock_mperblock_nblock_nperblock,
-                                          const AElementwiseOperation a_element_op,
-                                          const BElementwiseOperation b_element_op,
-                                          const CDEElementwiseOperation cde_element_op,
-                                          const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-                                          const Block2ETileMap block_2_etile_map)
+    kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
+                                      const ABDataType* __restrict__ p_b_grid,
+                                      EDataType* __restrict__ p_e_grid,
+                                      const index_t batch_count,
+                                      const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+                                      const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+                                      const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                                          e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                      const AElementwiseOperation a_element_op,
+                                      const BElementwiseOperation b_element_op,
+                                      const CDEElementwiseOperation cde_element_op,
+                                      const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+                                      const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
index 985752796b..efe8fe92c7 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
@@ -39,26 +39,25 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_gemm_xdl_cshuffle_v1(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            const FloatAB* __restrict__ p_b1_grid,
-            FloatC* __restrict__ p_c_grid,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const AccElementwiseOperation acc_element_op,
-            const B1ElementwiseOperation b1_element_op,
-            const CElementwiseOperation c_element_op,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                c_grid_desc_mblock_mperblock_nblock_nperblock,
-            const Block2CTileMap block_2_ctile_map,
-            const index_t batch_count,
-            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+    kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
+                                     const FloatAB* __restrict__ p_b_grid,
+                                     const FloatAB* __restrict__ p_b1_grid,
+                                     FloatC* __restrict__ p_c_grid,
+                                     const AElementwiseOperation a_element_op,
+                                     const BElementwiseOperation b_element_op,
+                                     const AccElementwiseOperation acc_element_op,
+                                     const B1ElementwiseOperation b1_element_op,
+                                     const CElementwiseOperation c_element_op,
+                                     const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+                                     const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+                                     const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+                                     const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                                         c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                     const Block2CTileMap block_2_ctile_map,
+                                     const index_t batch_count,
+                                     const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
index 12085edaae..811924a189 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
@@ -63,24 +63,24 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
-                                const ABDataType* __restrict__ p_b_grid,
-                                DsPointer p_ds_grid,
-                                EDataType* __restrict__ p_e_grid,
-                                const index_t batch_count,
-                                const AElementwiseOperation a_element_op,
-                                const BElementwiseOperation b_element_op,
-                                const CDEElementwiseOperation cde_element_op,
-                                const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
-                                const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
-                                const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                                const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-                                    e_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-                                const Block2ETileMap block_2_etile_map)
+    kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
+                            const ABDataType* __restrict__ p_b_grid,
+                            DsPointer p_ds_grid,
+                            EDataType* __restrict__ p_e_grid,
+                            const index_t batch_count,
+                            const AElementwiseOperation a_element_op,
+                            const BElementwiseOperation b_element_op,
+                            const CDEElementwiseOperation cde_element_op,
+                            const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
+                            const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
+                            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                                ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                            const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                                e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+                            const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
index 1b487502f4..a38e0d25e7 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
@@ -52,23 +52,23 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_dl_multiple_d(
-            const ABDataType* __restrict__ p_a_grid,
-            const ABDataType* __restrict__ p_b_grid,
-            DsPointer p_ds_grid,
-            EDataType* __restrict__ p_e_grid,
-            const index_t batch_count,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CDEElementwiseOperation cde_element_op,
-            const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
-            const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
-            const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
-            const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
-            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-            const Block2CTileMap block_2_ctile_map)
+    kernel_gemm_dl_multiple_d(
+        const ABDataType* __restrict__ p_a_grid,
+        const ABDataType* __restrict__ p_b_grid,
+        DsPointer p_ds_grid,
+        EDataType* __restrict__ p_e_grid,
+        const index_t batch_count,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
+        const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
+        const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
+        const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+        const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
     defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
index d38698af4b..2ae4794d00 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
@@ -42,32 +42,32 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_gemm_xdl_cshuffle_v1(
-            const A0B0B1DataType* __restrict__ p_a0_grid,
-            const A0B0B1DataType* __restrict__ p_b0_grid,
-            D0sPointer p_d0s_grid,
-            const A0B0B1DataType* __restrict__ p_b1_grid,
-            D1sPointer p_d1s_grid,
-            E1DataType* __restrict__ p_e1_grid,
-            const A0ElementwiseOperation a0_element_op,
-            const B0ElementwiseOperation b0_element_op,
-            const CDE0ElementwiseOperation cde0_element_op,
-            const B1ElementwiseOperation b1_element_op,
-            const CDE1ElementwiseOperation cde1_element_op,
-            const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
-            const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
-            const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-            const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                d1s_grid_desc_mblock_mperblock_nblock_nperblock,
-            const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                e1_grid_desc_mblock_mperblock_nblock_nperblock,
-            const Block2E1TileMap block_2_e1tile_map,
-            const index_t batch_count,
-            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+    kernel_batched_gemm_gemm_xdl_cshuffle_v1(
+        const A0B0B1DataType* __restrict__ p_a0_grid,
+        const A0B0B1DataType* __restrict__ p_b0_grid,
+        D0sPointer p_d0s_grid,
+        const A0B0B1DataType* __restrict__ p_b1_grid,
+        D1sPointer p_d1s_grid,
+        E1DataType* __restrict__ p_e1_grid,
+        const A0ElementwiseOperation a0_element_op,
+        const B0ElementwiseOperation b0_element_op,
+        const CDE0ElementwiseOperation cde0_element_op,
+        const B1ElementwiseOperation b1_element_op,
+        const CDE1ElementwiseOperation cde1_element_op,
+        const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
+        const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
+        const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+        const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            d1s_grid_desc_mblock_mperblock_nblock_nperblock,
+        const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            e1_grid_desc_mblock_mperblock_nblock_nperblock,
+        const Block2E1TileMap block_2_e1tile_map,
+        const index_t batch_count,
+        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -829,10 +829,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
              is_same_v && CheckDLayout() &&
             (is_same_v ||
-             is_same_v)&&CheckDLayout() &&
+             is_same_v) &&
+            CheckDLayout() &&
              is_same_v))
         {
             return false;
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
index 6624570b27..2e0b5da113 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
@@ -33,9 +33,9 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-        kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
+    kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -79,9 +79,9 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-        kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
+    kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // Pass two lds pointer is the key to tell compiler that ds_read/write
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
index de7d67f08b..851f6a5f97 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
@@ -39,26 +39,26 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_reduce_xdl_cshuffle_v1(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            FloatC* __restrict__ p_c_grid,
-            ReducePtrsGlobal p_reduces_grid,
-            const index_t batch_count,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CElementwiseOperation c_element_op,
-            const ReduceInElementwiseOperations reduce_in_element_ops,
-            const ReduceAccElementwiseOperations reduce_out_element_ops,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                c_grid_desc_mblock_mperblock_nblock_nperblock,
-            const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
-            const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
-            const Block2CTileMap block_2_ctile_map)
+    kernel_batched_gemm_reduce_xdl_cshuffle_v1(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        ReducePtrsGlobal p_reduces_grid,
+        const index_t batch_count,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CElementwiseOperation c_element_op,
+        const ReduceInElementwiseOperations reduce_in_element_ops,
+        const ReduceAccElementwiseOperations reduce_out_element_ops,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+        const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
+        const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
+        const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
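Editor's note: the *_2lds kernels above carry the comment that passing two LDS pointers "is the key to tell compiler that ds_read/write" to the two buffers are independent. A minimal sketch of the two-window idea (hypothetical tile size; not the library's pipeline):

#include <hip/hip_runtime.h>

// Two distinct __shared__ arrays give the compiler two provably
// non-aliasing LDS windows, so ds_writes filling one window can be
// scheduled around ds_reads draining the other.
extern "C" __global__ void pingpong_kernel(const float* in, float* out, int n)
{
    __shared__ float lds0[256];
    __shared__ float lds1[256];
    float* bufs[2] = {lds0, lds1};
    float acc      = 0.f;
    for(int tile = 0; tile * 256 < n; ++tile)
    {
        float* cur            = bufs[tile & 1]; // alternate windows per tile
        cur[threadIdx.x]      = in[tile * 256 + threadIdx.x];
        __syncthreads();
        acc += cur[(threadIdx.x + 1) % 256];
        __syncthreads();
    }
    out[threadIdx.x] = acc;
}

With a single buffer the compiler must assume every read may alias every write, which serializes the stages.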
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
index 1026118381..2e1684adb6 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
@@ -40,21 +40,21 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
-                                                       const B0DataType* __restrict__ p_b0_grid,
-                                                       const B1DataType* __restrict__ p_b1_grid,
-                                                       CDataType* __restrict__ p_c_grid,
-                                                       index_t M,
-                                                       index_t N,
-                                                       index_t K,
-                                                       index_t O,
-                                                       index_t G0,
-                                                       index_t G1,
-                                                       float alpha,
-                                                       bool input_permute,
-                                                       bool output_permute)
+    kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
+                                                   const B0DataType* __restrict__ p_b0_grid,
+                                                   const B1DataType* __restrict__ p_b1_grid,
+                                                   CDataType* __restrict__ p_c_grid,
+                                                   index_t M,
+                                                   index_t N,
+                                                   index_t K,
+                                                   index_t O,
+                                                   index_t G0,
+                                                   index_t G1,
+                                                   float alpha,
+                                                   bool input_permute,
+                                                   bool output_permute)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
@@ -178,15 +178,15 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
-                                           ODataType* __restrict__ p_out_grid,
-                                           index_t batch_size,
-                                           index_t sequence_length,
-                                           index_t head_count,
-                                           index_t head_size,
-                                           float alpha)
+    kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
+                                       ODataType* __restrict__ p_out_grid,
+                                       index_t batch_size,
+                                       index_t sequence_length,
+                                       index_t head_count,
+                                       index_t head_size,
+                                       float alpha)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
@@ -310,17 +310,17 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
-                                            const KVDataType* __restrict__ p_kv_grid,
-                                            ODataType* __restrict__ p_out_grid,
-                                            index_t batch_size,
-                                            index_t q_sequence_length,
-                                            index_t kv_sequence_length,
-                                            index_t head_count,
-                                            index_t head_size,
-                                            float alpha)
+    kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
+                                        const KVDataType* __restrict__ p_kv_grid,
+                                        ODataType* __restrict__ p_out_grid,
+                                        index_t batch_size,
+                                        index_t q_sequence_length,
+                                        index_t kv_sequence_length,
+                                        index_t head_count,
+                                        index_t head_size,
+                                        float alpha)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
index bae5c6019d..18b9e6ce74 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
@@ -43,30 +43,30 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            const FloatAB* __restrict__ p_b1_grid,
-            FloatC* __restrict__ p_c_grid,
-            D0sPointer p_d0s_grid,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const C0DEElementwiseOperation c0de_element_op,
-            const B1ElementwiseOperation b1_element_op,
-            const C1DEElementwiseOperation c1de_element_op,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-            const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                c1_grid_desc_mblock_mperblock_nblock_nperblock,
-            const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-            const Block2CTileMap block_2_ctile_map,
-            const index_t batch_count,
-            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-            const C0MatrixMask c0_matrix_mask)
+    kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        const FloatAB* __restrict__ p_b1_grid,
+        FloatC* __restrict__ p_c_grid,
+        D0sPointer p_d0s_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const C0DEElementwiseOperation c0de_element_op,
+        const B1ElementwiseOperation b1_element_op,
+        const C1DEElementwiseOperation c1de_element_op,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+        const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c1_grid_desc_mblock_mperblock_nblock_nperblock,
+        const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        const Block2CTileMap block_2_ctile_map,
+        const index_t batch_count,
+        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+        const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
index e846b0630b..ec0fb7b98d 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
@@ -42,27 +42,27 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            const FloatAB* __restrict__ p_b1_grid,
-            FloatC* __restrict__ p_c_grid,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const AccElementwiseOperation acc_element_op,
-            const B1ElementwiseOperation b1_element_op,
-            const CElementwiseOperation c_element_op,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                c_grid_desc_mblock_mperblock_nblock_nperblock,
-            const Block2CTileMap block_2_ctile_map,
-            const index_t batch_count,
-            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-            const C0MatrixMask c0_matrix_mask)
+    kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        const FloatAB* __restrict__ p_b1_grid,
+        FloatC* __restrict__ p_c_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const AccElementwiseOperation acc_element_op,
+        const B1ElementwiseOperation b1_element_op,
+        const CElementwiseOperation c_element_op,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+        const Block2CTileMap block_2_ctile_map,
+        const index_t batch_count,
+        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+        const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp
index abd6574d8c..cecd312879 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp
@@ -29,14 +29,13 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-        kernel_batched_gemm_wmma_cshuffle_v3(
-            typename GridwiseGemm::Argument
-                karg, // This works for now but it actually receives a
-                      // DeviceBatchedGemm_Wmma_CShuffleV3::Argument
-                      // argument through implicit conversion to base class!
-            const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
+    kernel_batched_gemm_wmma_cshuffle_v3(
+        typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
+                                              // DeviceBatchedGemm_Wmma_CShuffleV3::Argument
+                                              // argument through implicit conversion to base class!
+        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
 #if defined(__gfx11__)
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
index 494524b6f0..16d5feccf2 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
@@ -48,9 +48,9 @@ namespace device {
 template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
+    kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp
index 7d9555dc82..1419f5ee7c 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp
@@ -33,9 +33,9 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-        kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)
+    kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -71,9 +71,9 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-        kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)
+    kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // Pass two lds pointer is the key to tell compiler that ds_read/write
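Editor's note: the comment preserved through the kernel_batched_gemm_wmma_cshuffle_v3 hunk flags a real pitfall worth spelling out: kernel arguments are passed by value, so handing a derived argument struct to a kernel declared with the base type slices off the derived members. A hypothetical illustration (names invented for the sketch):

#include <hip/hip_runtime.h>
#include <cstdio>

struct BaseArg    { int m; int n; };
struct DerivedArg : BaseArg { int extra_stride; };

__global__ void kernel(BaseArg karg) // receives a sliced, by-value copy
{
    // Only the base members survive; karg has no extra_stride here.
    printf("%d %d\n", karg.m, karg.n);
}

int main()
{
    DerivedArg a{{64, 64}, 128};
    kernel<<<1, 1>>>(a); // implicit conversion to BaseArg is safe only
                         // while the kernel never needs derived members
    hipDeviceSynchronize();
}

This is exactly the situation the comment describes: it "works for now" because the kernel only touches the base-class fields.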
<< " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp index 9482812f75..dee3a51df7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -467,12 +467,12 @@ struct DeviceColumnToImageImpl float elapsed_time = 0.f; const auto kernel = kernel_tensor_rearrange, - GridwiseTensorRearrangeKernel>; + InputDataType, + OutputGridDesc, + OutputDataType, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch<>, + GridwiseTensorRearrangeKernel>; // Execute each set of independent filters for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index df5922a04f..b99032fb9f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -37,23 +37,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_abd_xdl_cshuffle( - AsPointer p_as_grid, - BsPointer p_bs_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, - const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_contraction_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 77974f84ae..de8e524dc3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -35,23 +35,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) 
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
index 77974f84ae..de8e524dc3 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
@@ -35,23 +35,23 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_contraction_multiple_d_xdl_cshuffle(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            FloatDsPointer p_ds_grid,
-            FloatE* __restrict__ p_e_grid,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CDEElementwiseOperation cde_element_op,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                ds_grid_desc_mblock_mperblock_nblock_nperblock,
-            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                e_grid_desc_mblock_mperblock_nblock_nperblock,
-            const Block2ETileMap block_2_etile_map)
+    kernel_contraction_multiple_d_xdl_cshuffle(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatDsPointer p_ds_grid,
+        FloatE* __restrict__ p_e_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+        const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
index 1b0db73fdd..dc07f8b445 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
@@ -35,17 +35,15 @@ auto CalculateMaxRead(const std::vector& lengths, const std::vector
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdlops_v2r3_for_conv3d(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            FloatC* __restrict__ p_c_grid,
-            const index_t num_batches,
-            const index_t a_batch_stride,
-            const index_t b_batch_stride,
-            const index_t c_batch_stride,
-            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-            const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CElementwiseOperation c_element_op,
-            const Block2CTileMap block_2_ctile_map)
+    kernel_gemm_xdlops_v2r3_for_conv3d(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        const index_t num_batches,
+        const index_t a_batch_stride,
+        const index_t b_batch_stride,
+        const index_t c_batch_stride,
+        const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+        const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+        const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CElementwiseOperation c_element_op,
+        const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
index b9467ac194..9e8c959f98 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
@@ -34,21 +34,21 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_dl_multiple_d(
-            const ABDataType* __restrict__ p_a_grid,
-            const ABDataType* __restrict__ p_b_grid,
-            DsPointer p_ds_grid,
-            EDataType* __restrict__ p_e_grid,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CDEElementwiseOperation cde_element_op,
-            const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
-            const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
-            const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
-            const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
-            const Block2CTileMap block_2_ctile_map)
+    kernel_gemm_dl_multiple_d(
+        const ABDataType* __restrict__ p_a_grid,
+        const ABDataType* __restrict__ p_b_grid,
+        DsPointer p_ds_grid,
+        EDataType* __restrict__ p_e_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
+        const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
+        const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
+        const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
+        const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \
     defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
index 47fb630ea9..8f4c41b69c 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
@@ -37,31 +37,30 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle(
-            const ABDataType* __restrict__ p_a_grid,
-            const ABDataType* __restrict__ p_b_grid,
-            DsPointer p_ds_grid,
-            EMeanVarDataType* __restrict__ p_e_grid,
-            EMeanVarDataType* __restrict__ p_welford_mean_grid,
-            EMeanVarDataType* __restrict__ p_welford_var_grid,
-            int32_t* __restrict__ p_welford_count_grid,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CDEElementwiseOperation cde_element_op,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                ds_grid_desc_mblock_mperblock_nblock_nperblock,
-            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                e_grid_desc_mblock_mperblock_nblock_nperblock,
-            const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
-                mean_var_grid_desc_mblock_mperblock_nblock,
-            const CountGridDescriptor_MBlock_MPerBlock_NBlock
-                count_grid_desc_mblock_mperblock_nblock,
-            const Block2ETileMap block_2_etile_map,
-            index_t NRaw)
+    kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle(
+        const ABDataType* __restrict__ p_a_grid,
+        const ABDataType* __restrict__ p_b_grid,
+        DsPointer p_ds_grid,
+        EMeanVarDataType* __restrict__ p_e_grid,
+        EMeanVarDataType* __restrict__ p_welford_mean_grid,
+        EMeanVarDataType* __restrict__ p_welford_var_grid,
+        int32_t* __restrict__ p_welford_count_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+        const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
+            mean_var_grid_desc_mblock_mperblock_nblock,
+        const CountGridDescriptor_MBlock_MPerBlock_NBlock count_grid_desc_mblock_mperblock_nblock,
+        const Block2ETileMap block_2_etile_map,
+        index_t NRaw)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
@@ -121,26 +120,26 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_welford_layernorm2d_second_half(
-            const EMeanVarDataType* __restrict__ p_e_grid,
-            const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
-            const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
-            const int32_t* __restrict__ p_in_welford_count_grid,
-            const GammaDataType* __restrict__ p_gamma_grid,
-            const BetaDataType* __restrict__ p_beta_grid,
-            HDataType* __restrict__ p_h_grid,
-            const EHGridDesc_M_N e_grid_desc_m_n,
-            const EHGridDesc_M_N h_grid_desc_m_n,
-            const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
-            const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
-            const GammaBetaGridDesc_N gamma_grid_desc_n,
-            const GammaBetaGridDesc_N beta_grid_desc_n,
-            index_t numMeanVarCountBlockTileIteration_N,
-            index_t NBlockClusterLength,
-            ComputeDataType epsilon,
-            HElementwiseOperation h_element_op)
+    kernel_welford_layernorm2d_second_half(
+        const EMeanVarDataType* __restrict__ p_e_grid,
+        const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
+        const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
+        const int32_t* __restrict__ p_in_welford_count_grid,
+        const GammaDataType* __restrict__ p_gamma_grid,
+        const BetaDataType* __restrict__ p_beta_grid,
+        HDataType* __restrict__ p_h_grid,
+        const EHGridDesc_M_N e_grid_desc_m_n,
+        const EHGridDesc_M_N h_grid_desc_m_n,
+        const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
+        const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
+        const GammaBetaGridDesc_N gamma_grid_desc_n,
+        const GammaBetaGridDesc_N beta_grid_desc_n,
+        index_t numMeanVarCountBlockTileIteration_N,
+        index_t NBlockClusterLength,
+        ComputeDataType epsilon,
+        HElementwiseOperation h_element_op)
 {
     GridwiseWelfordLayernorm::Run(p_e_grid,
                                   p_in_welford_mean_grid,
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
index c048e7249c..c1b3f98bc9 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
@@ -38,27 +38,27 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_multiple_d_multiple_r_xdl_cshuffle(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            FloatDsPointer p_ds_grid,
-            FloatE* __restrict__ p_e_grid,
-            FloatRsPointer p_rs_grid,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CDEElementwiseOperation cde_element_op,
-            const QsElementwiseOperation qs_element_op,
-            const RsElementwiseOperation rs_element_op,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                ds_grid_desc_mblock_mperblock_nblock_nperblock,
-            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                e_grid_desc_mblock_mperblock_nblock_nperblock,
-            const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
-            const Block2ETileMap block_2_etile_map)
+    kernel_gemm_multiple_d_multiple_r_xdl_cshuffle(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatDsPointer p_ds_grid,
+        FloatE* __restrict__ p_e_grid,
+        FloatRsPointer p_rs_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const QsElementwiseOperation qs_element_op,
+        const RsElementwiseOperation rs_element_op,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+        const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
+        const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
index f193b093d1..e36816df64 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
@@ -37,22 +37,22 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
-                                            const BDataType* __restrict__ p_b_grid,
-                                            DsPointer p_ds_grid,
-                                            EDataType* __restrict__ p_e_grid,
-                                            const AElementwiseOperation a_element_op,
-                                            const BElementwiseOperation b_element_op,
-                                            const CDEElementwiseOperation cde_element_op,
-                                            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-                                            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-                                            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                                                ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                                            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                                                e_grid_desc_mblock_mperblock_nblock_nperblock,
-                                            const Block2ETileMap block_2_etile_map)
+    kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
+                                        const BDataType* __restrict__ p_b_grid,
+                                        DsPointer p_ds_grid,
+                                        EDataType* __restrict__ p_e_grid,
+                                        const AElementwiseOperation a_element_op,
+                                        const BElementwiseOperation b_element_op,
+                                        const CDEElementwiseOperation cde_element_op,
+                                        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+                                        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+                                        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                                            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                        const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                                            e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                        const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp
index 90afc467d4..a921962c67 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp
@@ -16,6 +16,7 @@
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/host_utility/flush_cache.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -229,222 +230,28 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2
-            if(stream_config.log_level > 0)
-            {
-                arg.Print();
-                GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
-            }
+    using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common;

-            if(!GridwiseGemm::CheckValidity(arg))
-            {
-                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
-            }
-
-            index_t gdx, gdy, gdz;
-            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
-
-            float ave_time = 0;
-
-            index_t k_grain = arg.KBatch * KPerBlock;
-            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
-
-            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
-
-            const auto Run = [&](const auto& kernel) {
-                if(stream_config.flush_cache)
-                {
-                    Argument arg_ = arg;
-
-                    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
-                        arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
-                    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
-                        arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
-
-                    auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
-                                         sizeof(ADataType) / GridwiseGemm::APackedSize;
-                    auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
-                                         sizeof(BDataType) / GridwiseGemm::BPackedSize;
-
-                    ck::utility::RotatingMemWrapper rotating_mem(
-                        arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
-                    rotating_mem.Print();
-
-                    auto run_flush_cache = [&]() {
-                        // flush icache
-                        ck::utility::flush_icache();
-                        // rotating mem
-                        rotating_mem.Next();
-                        // clear c mem
-                        if(arg_.KBatch > 1)
-                            HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid,
-                                                           0,
-                                                           arg_.M * arg_.N * sizeof(CDataType),
-                                                           stream_config.stream_id_));
-                    };
-
-                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess(
-                        stream_config,
-                        run_flush_cache,
-                        kernel,
-                        dim3(gdx, gdy, gdz),
-                        dim3(BlockSize),
-                        0,
-                        arg_);
-                }
-                else
-                {
-                    if(arg.KBatch > 1)
-                        HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid,
-                                                       0,
-                                                       arg.M * arg.N * sizeof(CDataType),
-                                                       stream_config.stream_id_));
-
-                    ave_time = launch_and_time_kernel(
-                        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
-                }
-            };
-
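Editor's note: the Run() body removed above zeroes C whenever KBatch > 1 because split-K partial products are accumulated into C (typically via atomic adds), so each timed launch must start from a zeroed destination. A hedged host-side sketch of just that step (helper name and error handling invented for the example):

#include <hip/hip_runtime.h>
#include <cstdio>

inline void zero_c_for_splitk(
    void* p_c, size_t m, size_t n, size_t elem_size, int k_batch, hipStream_t stream)
{
    if(k_batch > 1)
    {
        // Accumulation across K-batches requires a zeroed destination buffer.
        if(hipMemsetAsync(p_c, 0, m * n * elem_size, stream) != hipSuccess)
            std::fprintf(stderr, "hipMemsetAsync failed\n");
    }
}

Skipping the memset when k_batch == 1 avoids a needless device write, since a single pass overwrites C completely.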
2 : 1; - } - else - { - return 1; - } - }(); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(arg.KBatch > 1) - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - } - else - { - // TODO: Implement - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if constexpr(std::is_same_v || - std::is_same_v) - { - if(arg.KBatch > 1 && ck::is_gfx11_supported()) - { - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; - } - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - return false; - } - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg); + return DeviceGemmCommon::IsSupportedArgument(arg); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp new file mode 100644 index 0000000000..1a68b35f1f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
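+// Usage sketch (illustrative only; `p_a`, `p_b`, `p_c`, `p_b_scale`, the element
+// ops and the template configuration below are placeholders, not part of this
+// header). The B-scale device op follows the usual CK create / check / run flow
+// exposed by MakeArgument, MakeInvoker and IsSupportedArgument further down:
+//
+//     auto gemm    = DeviceGemm_BScale_Wmma_CShuffleV3</* config */>{};
+//     auto arg     = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB,
+//                                      StrideC, StrideScaleB, p_b_scale, KBatch,
+//                                      a_element_op, b_element_op, c_element_op);
+//     auto invoker = gemm.MakeInvoker();
+//     if(gemm.IsSupportedArgument(arg))
+//     {
+//         invoker.Run(arg, StreamConfig{});
+//     }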
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale +{ + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const BScaleDataType* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + p_b_scale, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const void* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + 
BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_Wmma_CShuffleV3_BScale" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" + #include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Wmma_CShuffleV3_Common +{ + + using Argument = typename GridwiseGemm::Argument; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for the preparation and invocation of the + /// actual GPU kernel function. It usually determines the launch grid size, + /// prepares kernel arguments, and performs kernel configuration selection based + /// on runtime arguments. + /// + /// @note If appropriately configured, it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp index 2554ffea46..0f6457f48e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp @@ -32,20 +32,19 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_waveletmodel_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const EElementwiseOperation e_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const EElementwiseOperation e_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 
b_grid_desc_bk0_n_bk1, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 884175eaca..f32334cd91 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -28,14 +28,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_contraction_multiple_d_xdl_cshuffle( - const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_contraction_multiple_d_xdl_cshuffle( + const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 89a304fda4..fe9e4ff7e8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -74,24 +74,27 @@ template + InMemoryDataOperationEnum OutElementOp, + bool HasMainKBlockLoopInAllGemm, + bool NoMainKBlockLoopInAllGemm, + bool CTranspose> __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const std::array gemm_kernel_args, - const index_t gemms_count, - const AElementwiseOp a_element_op, - const BElementwiseOp b_element_op, - const CDEElementwiseOp cde_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n, - const index_t KBatch) + kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const std::array gemm_kernel_args, + const index_t gemms_count, + const AElementwiseOp a_element_op, + const BElementwiseOp b_element_op, + const CDEElementwiseOp cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) { #if(!defined(__HIP_DEVICE_COMPILE__) 
|| defined(__gfx9__)) // offset base pointer for each work-group @@ -101,16 +104,21 @@ __global__ void const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); const long_index_t e_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; + const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); @@ -141,11 +149,11 @@ __global__ void group_id = index_t((left + right) / 2); } - if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) { - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset, + p_b_grid + b_batch_offset + b_n_offset, p_ds_grid_grp, p_e_grid + e_batch_offset + e_n_offset, p_shared, @@ -162,22 +170,44 @@ __global__ void } else { - GridwiseGemm::template Run( - p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].block_2_ctile_map_, - KBatch, - k_idx); + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } + else + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + 
gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } } #else ignore = p_a_grid; @@ -278,7 +308,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // implementation we can avoid copying data to the workspace before kernel launch since the // number of groups is a runtime parameter. If the number of groups is larger than // MaxGroupedGemmGroupsNum then we run this kernel in a loop. - static constexpr index_t MaxGroupedGemmGroupsNum = 32; + static constexpr index_t MaxGroupedGemmGroupsNum = + ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0 + ? 1 + : 32; using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; @@ -296,24 +330,40 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - using ALayoutAfterTranspose = - std::conditional_t(), - tensor_layout::convolution::NHWGK, - std::conditional_t(), - tensor_layout::convolution::NDHWGK, - ALayout>>; - using BLayoutAfterTranspose = - std::conditional_t(), - tensor_layout::convolution::GKYXC, - std::conditional_t(), - tensor_layout::convolution::GKZYXC, - BLayout>>; - using ELayoutAfterTranspose = - std::conditional_t(), - tensor_layout::convolution::NHWGC, - std::conditional_t(), - tensor_layout::convolution::NDHWGC, - ELayout>>; + static constexpr bool isATensorColMajor = + (ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) && + (ABlockTransferSrcVectorDim == 1) && + (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool NeedTransposeKernel = + (isATensorColMajor == false) && (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool CTranspose = + (NeedTransposeKernel == false) && (is_same_v || + is_same_v); + + using ALayoutAfterTranspose = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::NHWGK, + std::conditional_t() && NeedTransposeKernel, + tensor_layout::convolution::NDHWGK, + ALayout>>; + using BLayoutAfterTranspose = std::conditional_t< + is_NGCHW_GKCYX_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::GKYXC, + std::conditional_t() && + NeedTransposeKernel, + tensor_layout::convolution::GKZYXC, + BLayout>>; + using ELayoutAfterTranspose = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::NHWGC, + std::conditional_t() && NeedTransposeKernel, + tensor_layout::convolution::NDHWGC, + ELayout>>; using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + EDataType, + 1, + index_t, + CTranspose>; static auto GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) @@ -357,15 +410,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DLayout, true, /*SplitConvN*/ ABDataType, - DDataType>; + DDataType, + 1, /*index_t NumGroupsToMerge = 1,*/ + index_t, /* typename IndexType = */ + CTranspose>; return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N(); }, Number{}); const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N(); - - return make_tuple( - a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + if constexpr(CTranspose) + { + return make_tuple( + b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + else + { + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } } // GridwiseGemm @@ -383,13 +446,34 @@ struct 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType + +#define GridwiseGemmCTransposeTemplateParameters \ + ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ + NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave, MXdlPerWave, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmCTranspose = std::conditional_t< + CTranspose, + GridwiseGemmMultipleD_xdl_cshuffle, + GridwiseGemm>; template static auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n) { - return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + return GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n); } template @@ -419,13 +503,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{})); using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); // block-to-e-tile map - using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using Block2ETileMap = + decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})); using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; @@ -630,14 +715,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 sizeof(EDataType); std::array a_g_n_k_wos_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, - a_g_n_k_wos_strides); + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides) + : a_g_n_k_wos_strides; std::array b_g_k_c_xs_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides) + : b_g_k_c_xs_strides; std::array e_g_n_c_wis_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths, - e_g_n_c_wis_strides); + NeedTransposeKernel ? 
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides) + : e_g_n_c_wis_strides; // populate Ds pointer static_for<0, NumDTensor, 1>{}([&](auto i) { @@ -737,12 +825,27 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 conv_N_per_block_ = conv_to_gemm_transform_.N_; - const auto a_grid_desc_ak0_m_ak1 = - conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); - - const auto b_grid_desc_bk0_n_bk1 = - conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + const auto a_grid_desc_ak0_m_ak1 = [&]() { + if constexpr(CTranspose) + { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + } + else + { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + } + }(); + const auto b_grid_desc_bk0_n_bk1 = [&]() { + if constexpr(CTranspose) + { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + } + else + { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + } + }(); DsGridDesc_M_N ds_grid_desc_m_n; // populate Ds desc @@ -764,7 +867,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DLayout, true, /*SplitConvN*/ ABDataType, - DDataType>; + DDataType, + 1, + index_t, + CTranspose>; ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{ a_g_n_k_wos_lengths, a_g_n_k_wos_strides_transposed, @@ -810,14 +916,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const auto GemmK = a_grid_desc_m_k.GetLength(I1); const bool HasMainKBlockLoop = - GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, k_batch_); + GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(GemmK, k_batch_); gemm_kernel_args_[gemms_count_ / MaxGroupedGemmGroupsNum][gemms_count_ % MaxGroupedGemmGroupsNum] = GemmArgs{a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, - GridwiseGemm:: + GridwiseGemmCTranspose:: MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n), MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -851,8 +957,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_; - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { // Use unmodified base strides a_in_transpose_desc_ = @@ -892,8 +997,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::size_t GetWorkspaceATensorSizeBytes() const { - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t a_acum = ck::accumulate_n( a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -908,8 +1012,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::size_t GetWorkspaceBTensorSizeBytes() const { - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t b_acum = ck::accumulate_n( b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -924,8 +1027,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::size_t GetWorkspaceETensorSizeBytes() const { - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t e_accum = ck::accumulate_n( e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -1030,24 +1132,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; EDataType* p_e_grid = arg.p_e_grid_; - - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if 
constexpr(NeedTransposeKernel) { - p_a_grid = type_convert(arg.p_workspace_); - p_e_grid = - type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / - sizeof(EDataType); - } + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) - { - p_b_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + } } - for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size(); gemm_set_id++) { @@ -1067,42 +1170,111 @@ DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } }; - auto launch_kernel = [&]() { - const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< - GridwiseGemm, - ADataType, // TODO: distinguish A/B datatype - typename GridwiseGemm::DsGridPointer, - EDataType, - MaxGroupedGemmGroupsNum, - GemmArgs, - AElementwiseOp, - BElementwiseOp, - CDEElementwiseOp, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - ElementOp>; + bool has_loop_in_all_gemm = true; + bool no_loop_in_all_gemm = true; + for(auto i = 0; i < gemms_count_for_set; i++) + { + has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_; + no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_; + } - return launch_and_time_kernel_with_preprocess(stream_config, - clear_workspace, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - arg.p_ds_grid_, - p_e_grid, - gemm_kernel_args, - gemms_count_for_set, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.compute_ptr_offset_of_n_, - arg.k_batch_); + auto launch_kernel = [&](auto has_main_k_block_loop, auto no_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool no_main_loop = no_main_k_block_loop.value; + if constexpr(CTranspose) + { + const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< + GridwiseGemmCTranspose, + ADataType, // TODO: distinguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + MaxGroupedGemmGroupsNum, + GemmArgs, + BElementwiseOp, + AElementwiseOp, + CDEElementwiseOp, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop, + CTranspose>; + + return launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_b_grid, + p_a_grid, + arg.p_ds_grid_, + p_e_grid, + gemm_kernel_args, + gemms_count_for_set, + arg.b_element_op_, + arg.a_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_, + arg.k_batch_); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distinguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + MaxGroupedGemmGroupsNum, + GemmArgs, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop, + 
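+                        // has_main_loop / no_main_loop forward the compile-time
+                        // HasMainKBlockLoopInAllGemm / NoMainKBlockLoopInAllGemm kernel
+                        // flags; when neither holds, the kernel keeps the per-group
+                        // HasMainKBlockLoop_ branch (see the dispatch below).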
CTranspose>; + + return launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + arg.p_ds_grid_, + p_e_grid, + gemm_kernel_args, + gemms_count_for_set, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_, + arg.k_batch_); + } }; - - ave_time += launch_kernel(); + if(has_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(no_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } } return ave_time; @@ -1116,9 +1288,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { arg.Print(); } + // Transpose from NGKHW to NHWGK - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { EDataType* p_e_in_grid = type_convert(arg.p_workspace_) + @@ -1208,8 +1380,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // Transpose from NHWGC to NGCHW - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const index_t grid_size = arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( @@ -1284,10 +1455,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } } - const index_t ConvG = arg.b_g_k_c_xs_lengths_[0]; - const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; - const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; - + const index_t ConvG = arg.b_g_k_c_xs_lengths_[0]; + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); // Specialization if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) @@ -1307,15 +1481,30 @@ if constexpr(is_same_v || is_same_v || is_same_v || - is_same_v || - is_same_v || - is_same_v) + is_same_v || NeedTransposeKernel) { if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) { return false; } } + else if(is_same_v || + is_same_v) + { + static_assert(NeedTransposeKernel == false); + + if constexpr(ABlockTransferSrcScalarPerVector != 1) + { + if(ABlockTransferSrcVectorDim != 1) + { + return false; + } + if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + } else { return false; @@ -1351,10 +1540,20 @@ is_same_v || is_same_v) { - // vector load D matrix from global memory - if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + if(CTranspose == false) { - ds_valid = false; + // vector load D matrix from global memory + if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + ds_valid = false; + } + } + else + { + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + ds_valid = false; + } } } else @@ -1376,10 +1575,20 @@ is_same_v || is_same_v) { - // vector store C matrix into global memory - if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + 
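+                // With CTranspose the A/B operands are swapped, so E is written along
+                // the input spatial dimensions; the vector-store check below therefore
+                // uses input_spatial_acum instead of ConvC.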
if(CTranspose == false) { - return false; + // vector store C matrix into global memory + if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } } } else @@ -1390,7 +1599,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // Gridwise GEMM size for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++) { - if(!GridwiseGemm::CheckValidity( + if(!GridwiseGemmCTranspose::CheckValidity( arg.a_grid_desc_m_k_container_[i], arg.b_grid_desc_n_k_container_[i], arg.ds_grid_desc_m_n_container_[i], @@ -1403,8 +1612,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } } - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 0b3f1a0255..3306e311b3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -35,18 +35,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_dlops_bwd_weight( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const index_t batch_count, - const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, - const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_dlops_bwd_weight( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, + const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp index a819b91b05..e5872816f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp @@ -77,21 +77,21 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl using CElementwiseGridDesc = remove_cvref_t; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<1, ElemsPerBlock>; using GridwiseElementwiseCast = GridwiseElementwise, - Tuple, - Tuple, - Tuple, - Block2TileMapElementwise, - WeiElementwiseOperation, - ElementwiseBlockSize, - I1, - ElemsPerBlock, - I1, - ElemsPerBlock / ElementwiseBlockSize, - Sequence<0, 1>, - Sequence<1>, - 
Sequence<1>, - I1, - I1>; + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + WeiElementwiseOperation, + ElementwiseBlockSize, + I1, + ElemsPerBlock, + I1, + ElemsPerBlock / ElementwiseBlockSize, + Sequence<0, 1>, + Sequence<1>, + Sequence<1>, + I1, + I1>; struct Argument : public BaseArgument { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 672c7dd2f7..601bf4eb5a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -43,22 +43,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_bwd_weight( - const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index c7c463f43d..8796f5520e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -44,16 +44,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( - typename GridwiseGemm::Argument karg, - [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + 
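+        // The descriptor and offset parameters below are [[maybe_unused]] because
+        // they are only read on the gfx9 device-compile path guarded by the
+        // preprocessor check inside the kernel body.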
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); @@ -99,16 +99,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 6c53161ded..6f6a3587ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -41,22 +41,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_bwd_weight( - const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + 
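+        // block_2_ctile_map translates workgroup ids into C-tile coordinates;
+        // compute_ptr_offset_of_batch supplies the per-batch base-pointer offsets.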
const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index f13a256d6b..bbaa04536c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -42,16 +42,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); @@ -100,16 +100,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 3e14f66a09..e7446bb995 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -72,23 +72,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - 
kernel_grouped_conv_fwd_dl_multiple_d( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, - const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_fwd_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 50e171e503..393ee80881 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -93,18 +93,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_dl( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - CDataType* __restrict__ p_c_grid, - const index_t batch_count, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_fwd_dl( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ defined(__gfx11__) || defined(__gfx12__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 6d2988ba24..ac40d363b5 
100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -81,25 +81,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( - AsPointer p_as_grid, - BsPointer p_bs_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfG compute_ptr_offset_of_groups, - const ComputePtrOffsetOfN compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfG compute_ptr_offset_of_groups, + const ComputePtrOffsetOfN compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) @@ -383,11 +383,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW() && NeedTransposeKernel, - ctc::NHWGC, - std::conditional_t() && NeedTransposeKernel, - ctc::NDHWGC, - ALay>>; + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::NHWGC, + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGC, + ALay>>; const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -403,11 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW() && NeedTransposeKernel, - ctc::GKYXC, - std::conditional_t() && NeedTransposeKernel, - ctc::GKZYXC, - BLay>>; + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::GKYXC, + std::conditional_t() && NeedTransposeKernel, + ctc::GKZYXC, + BLay>>; const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -423,11 +423,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW() && NeedTransposeKernel, - ctc::NHWGK, - std::conditional_t() && NeedTransposeKernel, - ctc::NDHWGK, - ELay>>; + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::NHWGK, + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGK, + ELay>>; const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 48424c16b9..a938820e6c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -21,7 +21,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" @@ -61,37 +61,46 @@ namespace { * */ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_fwd_xdl_cshuffle_v3( - typename GridwiseGemm::Argument karg, - [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N ds_grid_desc_m_n, + const EGridDesc_M_N c_grid_desc_m_n, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const long_index_t a_batch_offset = + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + + static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; + using DsGridPointer = typename GridwiseGemm::DsGridPointer; + DsGridPointer p_ds_grid_grp{}; + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_group_offset[i]; }); + + const long_index_t a_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = + const long_index_t b_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = + const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); const long_index_t a_n_offset = @@ -101,56 +110,77 @@ __global__ void __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - 
c_grid_desc_mblock_mperblock_nblock_nperblock); + using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + + GridwiseGemm::template Run( + karg.p_a_grid + a_group_offset + a_n_offset, + karg.p_b_grid + b_group_offset, + p_ds_grid_grp, + karg.p_c_grid + e_group_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); #else ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_m_n; + ignore = c_grid_desc_m_n; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; #endif // end of if (defined(__gfx9__)) } template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N ds_grid_desc_m_n, + const EGridDesc_M_N c_grid_desc_m_n, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const long_index_t a_batch_offset = + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + + static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; + using DsGridPointer = typename GridwiseGemm::DsGridPointer; + DsGridPointer p_ds_grid_grp{}; + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_group_offset[i]; }); + + const long_index_t a_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = + const long_index_t b_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = + const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); const long_index_t a_n_offset = @@ -163,22 +193,33 @@ __global__ void __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; + const auto block_2_ctile_map = Block2CTileMap{karg.M, 
karg.N, 4}; + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_group_offset + a_n_offset, + karg.p_b_grid + b_group_offset, + p_ds_grid_grp, + karg.p_c_grid + e_group_offset + e_n_offset, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); #else ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_m_n; + ignore = c_grid_desc_m_n; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; #endif // end of if (defined(__gfx9__)) } @@ -277,10 +318,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; static constexpr bool isMultiD = DsDataType::Size() > 0; - static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; + static constexpr bool isMultiABD = isMultiA && isMultiB && isMultiD; static constexpr bool DoElementwiseBeforeCShuffle = - !isMultiABD && is_same_v && + !isMultiD && is_same_v && !is_same_v; static constexpr index_t NumATensor = GetNumABTensors(); @@ -294,12 +335,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; + // Generate vector size for C & Ds + using CDEBlockTransferScalarPerVectors = + typename uniform_sequence_gen::type; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + using ComputePtrOffset = ComputePtrOffsetOfStridedBatch; + static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; @@ -321,11 +369,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKCYX_NGKHW(), - ctc::NHWGC, - std::conditional_t(), - ctc::NDHWGC, - ALay>>; + is_NGCHW_GKCYX_NGKHW(), + ctc::NHWGC, + std::conditional_t(), + ctc::NDHWGC, + ALay>>; const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -351,11 +399,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKCYX_NGKHW(), - ctc::GKYXC, - std::conditional_t(), - ctc::GKZYXC, - BLay>>; + is_NGCHW_GKCYX_NGKHW(), + ctc::GKYXC, + std::conditional_t(), + ctc::GKZYXC, + BLay>>; const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -381,11 +429,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKCYX_NGKHW(), - ctc::NHWGK, - std::conditional_t(), - ctc::NDHWGK, - ELay>>; + is_NGCHW_GKCYX_NGKHW(), + ctc::NHWGK, + std::conditional_t(), + ctc::NDHWGK, + ELay>>; const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); @@ -396,30 +444,81 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return out_gemmm_gemmn_desc; } + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. 
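+ // Illustration (a hedged sketch, assuming NumDTensor == 2): the generate_tuple call
+ // below then expands to the equivalent of
+ //   make_tuple(MakeEGridDescriptor_M_N(conv_to_gemm_transformer),
+ //              MakeEGridDescriptor_M_N(conv_to_gemm_transformer));
+ // i.e. every D tensor reuses the E descriptor builder. The per-D strides are applied
+ // later, when the Argument constructor rebuilds each entry from a transformer that
+ // carries that D's own ds_g_n_k_wos_strides.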
+ static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + // desc for problem definition constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using EGridDesc_M_N = remove_cvref_t(dummy_conv_to_gemm_transformer))>; - -#define GridwiseGemmV3TemplateParams \ - tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ - tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, \ - EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, \ - MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ - ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ - ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ - ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ - BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ - BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ - BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ - BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ - CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ - AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle + using DsGridDesc_M_N = + remove_cvref_t; // Use appropriate gridwise gemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3; + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + DsLayout, + tensor_layout::gemm::RowMajor, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + AComputeDataType, + BComputeDataType, + ADataType, + BDataType, + DoElementwiseBeforeCShuffle>; + + // #undef GridwiseGemmV3TemplateParams using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -493,37 +592,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 I0, I1>; - static auto - MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) - { - const index_t M = e_grid_desc_m_n.GetLength(I0); - const index_t N = e_grid_desc_m_n.GetLength(I1); - return 
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n, GridwiseGemm::CalculateMBlock(M), GridwiseGemm::CalculateNBlock(N)); - } - // desc for blockwise copy using AGridDesc_AK0_M_AK1 = remove_cvref_t( dummy_conv_to_gemm_transformer))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t( dummy_conv_to_gemm_transformer))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // Argument struct Argument : public BaseArgument { Argument(const void* p_as, const void* p_bs, - const std::array&, + const std::array& p_ds, void* p_e, const std::array& a_g_n_c_wis_lengths, const std::array& a_g_n_c_wis_strides, const std::array& b_g_k_c_xs_lengths, const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>&, - const std::array, NumDTensor>&, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, const std::array& e_g_n_k_wos_lengths, const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, @@ -535,6 +624,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const CDEElementwiseOperation& cde_element_op) : p_a_grid_{}, p_b_grid_{}, + p_ds_grid_{p_ds}, p_e_grid_{static_cast(p_e)}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( @@ -542,6 +632,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, @@ -561,13 +653,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 input_left_pads_, input_right_pads_}, conv_N_per_block_{conv_to_gemm_transformer_.N_}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_ak0_m_ak1_{ MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, b_grid_desc_bk0_n_bk1_{ MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, - e_grid_desc_m_n_{ - DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, @@ -583,12 +675,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 p_a_grid_ = static_cast(p_as); p_b_grid_ = static_cast(p_bs); + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // D batch stride + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + ds_g_n_k_wos_strides_[i], + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}; + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0]; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; - 
e_grid_desc_mblock_mperblock_nblock_nperblock_ = - MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); - if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -610,14 +723,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 e_in_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( e_g_n_k_wos_lengths, e_g_n_k_wos_strides); - elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ - b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; e_out_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( e_g_n_k_wos_lengths, e_g_n_k_wos_strides); elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{ e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; } @@ -680,6 +793,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; } @@ -687,6 +802,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // pointers (tuple if multi AB, pointer if no) const ADataType* p_a_grid_; const BDataType* p_b_grid_; + const std::array p_ds_grid_; EDataType* p_e_grid_; // for checking IsSupportedArgument() @@ -694,6 +810,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 std::array a_g_n_c_wis_strides_; std::array b_g_k_c_xs_lengths_; std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; std::array e_g_n_k_wos_lengths_; std::array e_g_n_k_wos_strides_; std::array conv_filter_strides_; @@ -705,18 +823,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 index_t num_group_; ConvToGemmFwdTransformer conv_to_gemm_transformer_; - index_t conv_N_per_block_; // tensor descriptors for block/thread-wise copy + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - EGridDesc_M_N e_grid_desc_m_n_; - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + ComputePtrOffset compute_ptr_offset_of_groups_; + ComputePtrOffset compute_ptr_offset_of_n_; // element-wise op AElementwiseOperation a_element_op_; @@ -759,6 +877,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; index_t gdx, gdy, gdz; + // TODO: Do we want to support kbatch ?? 
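+ // For orientation (hedged sketch; assumes the grid mapping stays as set up in this
+ // file): kbatch is pinned to 1 below, the triple is launched as dim3(gdx, gdy, gdz),
+ // and the kernel recovers its coordinates with
+ //   g_idx = blockIdx.y; // convolution group
+ //   n_idx = blockIdx.z; // N-slice of size conv_N_per_block_
+ // so only gdx tracks the M x N tile count of the underlying GEMM.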
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); @@ -784,20 +903,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 sizeof(EDataType); } - typename GridwiseGemm::Argument gemm_arg{p_a_grid, - p_b_grid, - p_e_grid, - GemmM, - GemmN, - GemmK, - I0, - I0, - I0, - I1, - false, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_}; + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, + p_b_grid, + arg.p_ds_grid_, + p_e_grid, + GemmM, + GemmN, + GemmK, + // No need to set strides, we pass descs to kernel + I0, + I0, + {}, + I0, + I1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) @@ -827,24 +949,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 gemm_arg_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, arg.compute_ptr_offset_of_groups_, arg.compute_ptr_offset_of_n_); } else { - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - gemm_arg, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_); + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); } }; @@ -854,15 +977,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } // Tail number could be One to Seven @@ -870,30 +994,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::One>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Full>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } @@ -903,10 +1029,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, 
DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -921,10 +1048,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -939,10 +1067,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -957,10 +1086,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -975,10 +1105,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -993,10 +1124,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1012,10 +1144,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1026,10 +1159,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1041,48 +1175,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - 
GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } } } + // has_main_k_block_loop else { // Tail number always 1 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } } @@ -1095,6 +1233,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 float avg_time = 0.f; if constexpr(!isMultiABD) { + // Transpose to NHWGC layout if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -1147,6 +1286,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 avg_time += RunGemm(arg, stream_config); + // Transpose result back to NGCHW if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -1205,6 +1345,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if constexpr(isMultiABD) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The MultiABD is not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } // check device @@ -1213,12 +1358,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // FIXME: re-enable fp64 when SWDEV-335738 is fixed if constexpr(!(is_same_v || is_same_v)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "On gfx908 the accumulation data type must be one of fp32 or int32!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } if(!ck::is_xdl_supported()) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Current device does not support xdl instructions!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } @@ -1236,6 +1394,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The input parameters do not align with specialization " "Filter1x1Stride1Pad0!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1252,6 +1417,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "The input parameters do not align with specialization Filter1x1Pad0!"
+ << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1268,11 +1440,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[A Layout] The number of input channels is not a multiple of " + "ABlockTransferSrcScalarPerVector!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } else { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } @@ -1288,11 +1472,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[B Layout] The number of input channels is not a multiple of " + "BBlockTransferSrcScalarPerVector!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } else { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } @@ -1301,11 +1497,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The G * C is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The G * K is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } @@ -1316,11 +1526,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The input_spatial_acum is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The output_spatial_acum is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } @@ -1340,6 +1564,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] One of the transposed vectors is exceeding 2GB " + "memory size!" 
+ << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1354,17 +1585,36 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[E Layout] The K is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } else { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported E Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } // Gridwise gemm v3 doesn't verify descriptors size if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB()) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "[conv_to_gemm_transformer_] One of the descriptors is bigger than 2GB!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } @@ -1374,8 +1624,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - typename GridwiseGemm::Argument gemm_arg{ - nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, I1 /*KBatch*/}; + typename GridwiseGemm::Argument gemm_arg{nullptr, + nullptr, + {}, + nullptr, + GemmM, + GemmN, + GemmK, + I0, + I0, + {}, + I0, + I1 /*KBatch*/, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; return GridwiseGemm::CheckValidity(gemm_arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index ec1a05366e..1e5c67aac7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -131,29 +131,29 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batch_gemm_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - RsPointer p_rs_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const QsElementwiseOperation qs_element_op, - const RsElementwiseOperation rs_element_op, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batch_gemm_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + RsPointer p_rs_grid, + const 
AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const QsElementwiseOperation qs_element_op, + const RsElementwiseOperation rs_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 9988367959..b1494a36bf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -41,16 +41,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle( - Array gemm_desc_kernel_args, - const index_t gemms_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op, - const ComputePtrOffset compute_ptr_offset_of_groups, - const ComputePtrOffset compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle( + Array gemm_desc_kernel_args, + const index_t gemms_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 21afc06040..7cfc73fab6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -36,14 +36,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const index_t grid_size_grp, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const index_t grid_size_grp, + const AElementwiseOperation a_element_op, + const BElementwiseOperation 
b_element_op, + const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index 10d8a4a44d..d0d613af8f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -32,13 +32,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index 18872e38ea..7b5dd55a8f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -576,16 +576,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(dev_gemm_args == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } if(dev_gemm_workspace == nullptr) { std::ostringstream err; - err << "The gemm workspace buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm workspace buffer is not allocated!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } @@ -624,16 +624,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(arg.p_dev_gemm_kargs_ == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } if(arg.p_workspace_ == nullptr) { std::ostringstream err; - err << "The gemm workspace buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm workspace buffer is not allocated!" 
<< " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } @@ -711,8 +711,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(not_all_have_kbatch_value_same) { std::ostringstream err; - err << "Not all gemms have same kbatch value (=1 or >1)! " - << "group [" << i << "], kbatch: " << gemm_arg.k_batch + err << "Not all gemms have same kbatch value (=1 or >1)! " << "group [" << i + << "], kbatch: " << gemm_arg.k_batch << ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 61058dec2b..38bb19b712 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -60,13 +60,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) @@ -600,8 +600,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop if(dev_gemm_args == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } @@ -629,8 +629,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop if(arg.p_dev_gemm_args_ == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" 
<< " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 3fb2c5ae86..1754b542c5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -32,16 +32,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( - const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const AccElementwiseOperation acc_element_op, - const B1ElementwiseOperation b1_element_op, - const CElementwiseOperation c_element_op) + kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( + const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index cbee4e09f4..a528149ecd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -31,13 +31,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op) + kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 8fe71fb9a2..81134465af 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -38,17 +38,17 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - uint32_t* barrier_count, - 
const index_t barrier_size_grp, - const index_t group_count, - const index_t grid_size_grp, - const index_t KBatch, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op) + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + uint32_t* barrier_count, + const index_t barrier_size_grp, + const index_t group_count, + const index_t grid_size_grp, + const index_t KBatch, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 01f52881f4..ea14087698 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -33,13 +33,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) + kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); @@ -416,8 +416,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK1)! " - << "group [" << i << "], kbatch: " << kbatch + err << "Not all gemms have same kbatch value (=1 or >1)! 
" << "group [" << i + << "], kbatch: " << kbatch << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp index 67a100a112..b66ab997bb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp @@ -45,21 +45,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, - const B0DataType* __restrict__ p_b0_grid, - const B1DataType* __restrict__ p_b1_grid, - CDataType* __restrict__ p_c_grid, - index_t M, // SequenceQ - index_t N, // SequenceK - index_t K, // HeadDim - index_t O, // SequenceK - index_t G0, // Batch - index_t G1, // HeadNum - float alpha, - bool input_permute, - bool output_permute) + kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 08d177035e..27d3c378ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -325,12 +325,50 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; - RunKernel(kernel); + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v2 support now"); } } #endif diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 48a10f219c..efa85a357c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -100,64 +100,64 @@ struct DeviceMoeGemmBlockScale { static constexpr index_t NumDTensor = DsDataType::Size(); using GridwiseGemm = GridwiseMoeGemmBlockScale< - ALayout, - BLayout, - DsLayout, - CLayout, - ADataType, - BDataType, - GemmAccDataType, - CShuffleDataType, - DsDataType, - CDataType, - 
AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - BlockSize, - ScaleBlockM, - ScaleBlockN, - ScaleBlockK, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ActivationOP, - NSwizzle, - IsInputGemm, - MulRoutedWeight, - IndexType, - ComputeTypeA, - ComputeTypeB, - LDSTypeA, - LDSTypeB>; + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; using Argument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp index 6dc3a5f881..4bf38d9d1f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp @@ -92,62 +92,62 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle; + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + 
BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB>; using Argument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp index cc88c1a104..e196ed5e3a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -44,21 +44,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid, - const B0DataType* __restrict__ p_b0_grid, - const B1DataType* __restrict__ p_b1_grid, - CDataType* __restrict__ p_c_grid, - index_t M, // SequenceQ - index_t N, // SequenceK - index_t K, // HeadDim - index_t O, // SequenceK - index_t G0, // Batch - index_t G1, // HeadNum - float alpha, - bool input_permute, - bool output_permute) + kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index 63b49d9aa0..c1d3aa43de 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -36,25 +36,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_d_xdl_cshuffle( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - const index_t batch_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1, - const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch 
compute_ptr_offset_of_batch, - const Block2ETileMap block_2_etile_map) + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index 9fe2f0d976..cc500bb9cb 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -33,7 +33,7 @@ struct MaskDisabledPredicate }; __host__ __device__ constexpr bool - IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const + IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const { return false; } diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 34c76b89e4..d86f01e255 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -379,10 +379,10 @@ struct AddClamp __host__ __device__ constexpr void operator()(half_t& y, const half_t& x0, const half_t& x1) const { - const half_t a = x0 + x1; - y = a > type_convert<half_t>(floor_) - ? (a < type_convert<half_t>(ceil_) ? a : type_convert<half_t>(ceil_)) - : type_convert<half_t>(floor_); + const half_t floor = type_convert<half_t>(floor_); + const half_t ceil = type_convert<half_t>(ceil_); + const half_t a = x0 + x1; + y = a > floor ? (a < ceil ? 
a : ceil) : floor; }; template <> diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 8f829496da..4a87e8a277 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -266,7 +266,7 @@ struct DequantPack8 dst.template AsType()(Number<3>{}) = type_convert(src.template AsType()[Number<3>{}]); - y = dst.template AsType()[Number<0>{}]; + y = dst.template AsType()[Number<0>{}]; #endif } diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index 02dba97430..36dc8aa6ba 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -527,11 +527,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ABDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -997,9 +997,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) { const auto c_ds_src_data_refs = concat_tuple_of_reference( tie(e_thread_buf[i]), - generate_tie( - [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; }, - Number{})); + generate_tie([&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; }, + Number{})); auto e_dst_data_refs = tie(e_thread_buf(i)); unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs); }); @@ -1124,7 +1123,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle }); } // shuffle C + Ds + welford + write out - } // run + } // run }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp index e3c50ef06c..cc3306e1bd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp @@ -228,9 +228,8 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d static_for<0, MThreadSliceSize, 1>{}([&](auto I) { const auto c_ds_buf_refs = concat_tuple_of_reference( tie(accu_value_buf[I]), - generate_tie( - [&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; }, - Number{})); + generate_tie([&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; }, + Number{})); unpack2(out_elementwise_op, tie(out_value_buf(I)), c_ds_buf_refs); }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 53a45c7f16..e8f8caa10d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -372,11 +372,11 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle : false; constexpr auto is_scale_mfma = false; constexpr auto mfma = MfmaSelector::selected_mfma; + Gemm0MPerXdl, + Gemm0NPerXdl, + A0B0B1DataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma; constexpr auto N3 = mfma.num_groups_per_blk; constexpr auto N5 = mfma.group_size; return transform_tensor_descriptor( @@ -669,11 +669,11 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_A0K1_B0K1, MfmaSelector::selected_mfma.k_per_blk); + Gemm0MPerXdl, + Gemm0NPerXdl, + A0B0B1DataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm0 = BlockwiseGemmXdlops_v2< BlockSize, @@ -1176,18 +1176,16 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle // tuple of reference to C/Ds tensor descriptors const auto c1_d1s_desc_refs = concat_tuple_of_reference( tie(c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c1_d1s_buf_refs = concat_tuple_of_reference( tie(c1_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return d1s_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return d1s_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c1_d1s_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index 1326c5d62d..839a68a978 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -24,14 +24,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, - const OutGridDescTuple out_grid_desc_tuple, - const InDataTypePointerTuple p_in_global_tuple, - const OutDataTypePointerTuple p_out_global_tuple, - const Block2TileMap block_2_tile_map, - const ElementwiseOperation elementwise_op) + kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op) { GridwiseElementwiseFunctor::Run(in_grid_desc_tuple, out_grid_desc_tuple, @@ -56,20 +56,20 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, - const InBGridDescTuple in_grid_desc_tuple_b, - const OutAGridDescTuple out_grid_desc_tuple_a, - const OutBGridDescTuple 
out_grid_desc_tuple_b, - const InADataTypePointerTuple p_in_global_tuple_a, - const InBDataTypePointerTuple p_in_global_tuple_b, - const OutADataTypePointerTuple p_out_global_tuple_a, - const OutBDataTypePointerTuple p_out_global_tuple_b, - const Block2TileMapA block_2_tile_map_a, - const Block2TileMapB block_2_tile_map_b, - const ElementwiseOperation elementwise_op, - const index_t a_grid_size) + kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, + const InBGridDescTuple in_grid_desc_tuple_b, + const OutAGridDescTuple out_grid_desc_tuple_a, + const OutBGridDescTuple out_grid_desc_tuple_b, + const InADataTypePointerTuple p_in_global_tuple_a, + const InBDataTypePointerTuple p_in_global_tuple_b, + const OutADataTypePointerTuple p_out_global_tuple_a, + const OutBDataTypePointerTuple p_out_global_tuple_b, + const Block2TileMapA block_2_tile_map_a, + const Block2TileMapB block_2_tile_map_b, + const ElementwiseOperation elementwise_op, + const index_t a_grid_size) { if(get_block_1d_id() < a_grid_size) { @@ -112,27 +112,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_elementwise_batched_dual( - const InAGridDescTuple in_grid_desc_tuple_a, - const InBGridDescTuple in_grid_desc_tuple_b, - const OutAGridDescTuple out_grid_desc_tuple_a, - const OutBGridDescTuple out_grid_desc_tuple_b, - const InADataTypePointerTuple p_in_global_tuple_a, - const InBDataTypePointerTuple p_in_global_tuple_b, - const OutADataTypePointerTuple p_out_global_tuple_a, - const OutBDataTypePointerTuple p_out_global_tuple_b, - const Block2TileMapA block_2_tile_map_a, - const Block2TileMapB block_2_tile_map_b, - const ElementwiseOperation elementwise_op, - const index_t a_grid_size, - const index_t batch_count_a, - const index_t batch_count_b, - const std::array input_batch_strides_a, - const std::array input_batch_strides_b, - const std::array output_batch_strides_a, - const std::array output_batch_strides_b) + kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a, + const InBGridDescTuple in_grid_desc_tuple_b, + const OutAGridDescTuple out_grid_desc_tuple_a, + const OutBGridDescTuple out_grid_desc_tuple_b, + const InADataTypePointerTuple p_in_global_tuple_a, + const InBDataTypePointerTuple p_in_global_tuple_b, + const OutADataTypePointerTuple p_out_global_tuple_a, + const OutBDataTypePointerTuple p_out_global_tuple_b, + const Block2TileMapA block_2_tile_map_a, + const Block2TileMapB block_2_tile_map_b, + const ElementwiseOperation elementwise_op, + const index_t a_grid_size, + const index_t batch_count_a, + const index_t batch_count_b, + const std::array input_batch_strides_a, + const std::array input_batch_strides_b, + const std::array output_batch_strides_a, + const std::array output_batch_strides_b) { static_assert(InAGridDescTuple::Size() == NumInputsA && InADataTypePointerTuple::Size() == NumInputsA); @@ -217,17 +216,17 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, - const OutGridDescTuple out_grid_desc_tuple, - const InDataTypePointerTuple p_in_global_tuple, - const OutDataTypePointerTuple p_out_global_tuple, - const Block2TileMap block_2_tile_map, - const ElementwiseOperation elementwise_op, - const index_t 
batch_count, - const std::array input_batch_strides, - const std::array output_batch_strides) + kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op, + const index_t batch_count, + const std::array input_batch_strides, + const std::array output_batch_strides) { static_assert(InGridDescTuple::Size() == NumInputs && InDataTypePointerTuple::Size() == NumInputs); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 21dac6f9e9..fab0fbab1d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -34,21 +34,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - const ScaleDataType* __restrict__ p_scale_grid, - CDataType* __restrict__ p_c_grid, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const ScaleGridDesc scale_grid_desc, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const ScaleGridDesc scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index f406bfb95a..6e73f0955b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -40,31 +40,31 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC0* __restrict__ p_bias_grid, - const FloatC1* __restrict__ p_d0_grid, - ReducePtrsGlobal p_reduces_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const C1ElementwiseOperation c1_element_op, - const 
ReduceInElementwiseOperations reduce_in_element_ops, - const ReduceAccElementwiseOperations reduce_out_element_ops, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c0_grid_desc_mblock_mperblock_nblock_nperblock, - const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c1_grid_desc_mblock_mperblock_nblock_nperblock, - const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const C1ElementwiseOperation c1_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 562b9b8ffa..5e779b2881 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -28,15 +28,15 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp index b473d7cbf2..7deda48f7b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ 
-21,12 +21,12 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif #if CK_USE_WAVES_PER_EU - __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) #endif - kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) + kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -154,17 +154,10 @@ struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 054aca2936..c37ffb6263 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -687,11 +687,11 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle static constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + BComputeDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -863,18 +863,16 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 127d889572..df5c8b10f3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -952,7 +952,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 }); // copy c, d, e + reduction } // shuffle C + Ds + reduction + write out - } // Run + } // Run }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index de6c9c1601..36eb4489e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -34,25 +34,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_multiple_d_wmma_cshuffle( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc, - const BGridDesc_BK0_N_BK1 b_grid_desc, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc, + const BGridDesc_BK0_N_BK1 b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group @@ -127,25 +127,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_d_wmma_cshuffle( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const index_t batch_count, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2CTileMap block_2_etile_map) + kernel_contraction_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ 
p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // printf("entry kernel launch"); @@ -219,23 +219,22 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_mupltipe_d_wmma_cshuffle( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_mupltipe_d_wmma_cshuffle(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index acbccf1889..318ff59383 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -657,11 +657,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + BComputeDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -856,18 +856,16 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - 
Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 1e79d67f93..769bc5b877 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -38,23 +38,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -73,18 +73,18 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, block_2_etile_map); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_ds_grid; - ignore = p_e_grid; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = block_2_etile_map; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; 
+ ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; #endif } @@ -814,18 +814,16 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad // A tuple of reference to C/Ds tensor descriptors. const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // A tuple of reference to C/Ds grid buffers. const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // A tuple of starting index of C/Ds blockwise copy. const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index 5815eb5b0b..85b5b5faab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -611,11 +611,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + AComputeType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -855,18 +855,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index db227bb7ef..b257fa4aa3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -35,24 +35,24 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) 
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_reduce_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - ReducePtrsGlobal p_reduces_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const ReduceInElementwiseOperations reduce_in_element_ops, - const ReduceAccElementwiseOperations reduce_out_element_ops, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 70301c326a..b4848c7077 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -593,11 +593,11 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ABDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -769,18 +769,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = 
container_concat( @@ -1032,11 +1030,11 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ABDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp index f64838ea4e..1b4c2666ab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -607,11 +607,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ComputeType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -845,18 +845,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 4458b9356d..51cd5ada91 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -31,19 +31,19 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - CDataType* __restrict__ p_c_grid, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + 
const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index f3354cd5dd..9a8d09e5e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -14,47 +14,10 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); - - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg); -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; -#endif -} - /// @brief "Universal" GEMM kernel with SplitK support. 
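/// SplitK partitions the reduction (K) dimension of the GEMM across several workgroups; each workgroup computes a partial result for its C tile, and the partial results are then combined (for example via atomic adds, as in the kernel entry point above). (Added doc note; wording is the editor's, not from the original patch.)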
/// /// @par Overview @@ -207,391 +170,143 @@ template struct GridwiseGemm_wmma_cshuffle_v3 + : GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB> { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), - WmmaSelector::selected_wmma - .k_per_wmma); + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + + using Base::APackedSize; + using Base::BPackedSize; + + using 
Base::CalculateAK0Padded; + using Base::CalculateBK0Padded; + using Base::CalculateKPadded; + using Base::CalculateKRead; + using Base::CalculateMBlock; + using Base::CalculateMPadded; + using Base::CalculateNBlock; + using Base::CalculateNPadded; + using Base::MakeAGridDescriptor_AK0_M_AK1; + using Base::MakeBGridDescriptor_BK0_N_BK1; + using Base::MakeCGridDescriptor_M_N; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + + using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t APackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) - { - return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); - } - - __host__ static auto CalculateMPadded(index_t M) - { - return math::integer_least_multiple(M, MPerBlock); - } - - __host__ static auto CalculateNPadded(index_t N) - { - return math::integer_least_multiple(N, NPerBlock); - } - - __host__ static auto CalculateKPadded(index_t K) - { - return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; - } - - __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); - } - - __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); - } - - __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * KPerBlock; - } - - __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) - { - constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); - auto K_t = K_Batch * KReadVec; - return (K + K_t - 1) / K_t * KReadVec; - } - - __host__ static auto CalculateMBlock(index_t M) - { - return math::integer_divide_ceil(M, MPerBlock); - } - - __host__ static auto CalculateNBlock(index_t N) - { - return math::integer_divide_ceil(N, NPerBlock); - } - - template - __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) - { - // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 - constexpr auto K0 = BlockDesc{}.GetLength(I0); - constexpr auto K1 = BlockDesc{}.GetLength(I2); -#ifdef __gfx12__ - constexpr auto KRow = I2; -#else - constexpr auto KRow = I1; -#endif - return transform_tensor_descriptor( - BlockDesc{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - } - - __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( - index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - if constexpr(GemmSpec == 
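The Calculate* helpers pulled in from the base keep the arithmetic of the removed bodies below: round M and N up to the block tile and size the K0 dimension per split-K batch. A standalone sketch with illustrative tile sizes (MPerBlock = 128, KPerBlock = 64, AK1 = 8):

// Same arithmetic as the hoisted helpers, specialized to example tiles.
constexpr int integer_least_multiple(int x, int m) { return (x + m - 1) / m * m; }
constexpr int integer_divide_ceil(int x, int m) { return (x + m - 1) / m; }

constexpr int CalculateMPadded(int M) { return integer_least_multiple(M, 128); }
constexpr int CalculateAK0Padded(int K, int KBatch = 1)
{
    const int K_t = KBatch * 64;           // KBatch * KPerBlock
    return (K + K_t - 1) / K_t * (64 / 8); // ceil(K / K_t) * (KPerBlock / AK1Value)
}

static_assert(CalculateMPadded(100) == 128);
static_assert(CalculateAK0Padded(130, 1) == 24); // ceil(130/64) = 3 tiles, 3 * 8 AK0 slices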
GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both M and K - const auto a_grid_desc_m_k = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(MPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad M, but not K - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_right_pad_transform(M, MPad - M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad K, but not M - const auto a_grid_desc_m_k = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else - { - static_assert(!PermuteA, "PermuteA is not supported"); - - // not pad M or K - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - } - - __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( - index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); - } - }(); - - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - static_assert(!(is_same_v, pk_i4_t> && - GemmSpec != GemmSpecialization::Default), - "pk_i4_t does not support padding"); - - if constexpr(GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both N and K - const auto b_grid_desc_n_k = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_right_pad_transform(N, NPad - N), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(NPad)), - make_tuple(Sequence<1>{}, 
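Every branch of the removed descriptor code ends in the same make_unmerge_transform step; the relationship it encodes between the flat K index and the (AK0, AK1) pair is plain mixed-radix splitting. A self-contained check:

// What make_unmerge_transform(make_tuple(AK0, AK1Value)) encodes for K:
// a flat k maps to (ak0, ak1) with k = ak0 * AK1 + ak1.
#include <cassert>

int main()
{
    constexpr int AK1 = 8;
    for(int k = 0; k < 64; ++k)
    {
        const int ak0 = k / AK1;
        const int ak1 = k % AK1;
        assert(ak0 * AK1 + ak1 == k); // round-trips exactly
    }
    return 0;
}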
Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad N, but not K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad K, but not N - const auto b_grid_desc_n_k = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - if constexpr(!PermuteB) - { - // not pad N or K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - // Pre-shuffled Weight - // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] - constexpr index_t BK01 = KPerBlock / BK1Value; - const index_t BK0_ = StrideB / BK1Value; - const index_t BK00 = BK0_ / BK01; - - const auto b_grid_desc_bk00_n_bk01_bk1_permute = - make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); - - const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( - b_grid_desc_bk00_n_bk01_bk1_permute, - make_tuple(make_merge_transform(make_tuple(BK00, BK01)), - make_pass_through_transform(make_tuple(N)), - make_pass_through_transform(BK1Value)), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_grid_desc_bk0_n_bk1_permute; - } - } - } - - template - __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) - { - constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); - - return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); - } - - template - __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) - { - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - - return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); - } - - __host__ __device__ static auto - MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) - { - const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(N, NPad - N)), - 
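The PermuteB branch above stores pre-shuffled weights as BGlobal[BK00][N][BK01][BK1] and merges (BK00, BK01) back into BK0. The resulting element offset for a logical (k, n), sketched with a toy check (illustrative sizes):

// Offset math for the pre-shuffled weight layout:
// k = (bk00 * BK01 + bk01) * BK1 + bk1 lives at
// ((bk00 * N + n) * BK01 + bk01) * BK1 + bk1 in packed storage.
#include <cassert>
#include <cstdint>

std::int64_t preshuffled_b_offset(int k, int n, int N, int BK01, int BK1)
{
    const int bk1  = k % BK1;
    const int bk0  = k / BK1; // merged (bk00, bk01)
    const int bk01 = bk0 % BK01;
    const int bk00 = bk0 / BK01;
    return ((std::int64_t(bk00) * N + n) * BK01 + bk01) * BK1 + bk1;
}

int main()
{
    // toy check: N = 4, KPerBlock = 16, BK1 = 4 -> BK01 = KPerBlock / BK1 = 4
    assert(preshuffled_b_offset(/*k=*/0, /*n=*/0, 4, 4, 4) == 0);
    assert(preshuffled_b_offset(/*k=*/4, /*n=*/0, 4, 4, 4) == 4); // next bk01 slot
    return 0;
}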
make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - // TODO: Investigate why this path is not used in the original - // gridwise_gemm_xdl_cshuffle_v3.hpp -#if 0 - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - if constexpr(GemmSpec == GemmSpecialization::MNPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad M, but not N - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad N, but not M - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - // not pad M or N - return c_grid_desc_mraw_nraw; - } -#endif - } + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; struct Problem { @@ -622,20 +337,11 @@ struct GridwiseGemm_wmma_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -749,943 +455,14 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t c_reduce_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{} * AK1Number, AK1Number, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; - constexpr auto MLdsLayer = LdsSize < 1 ? 
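The xor_with_modulo path above is one instance of the classic XOR swizzle: permute the inner index by the outer one so that column-wise readers hit distinct LDS banks. A generic sketch of the idea only; the exact CK transform composes additional unmerge and merge steps:

// Generic XOR swizzle (illustrative, not the exact CK transform):
// with a power-of-two width, col ^ (row % width) stays in range and is
// its own inverse, so write and read sides agree on the permutation.
#include <cassert>

int swizzled_col(int row, int col, int width) { return col ^ (row % width); }

int main()
{
    const int width = 8;
    for(int r = 0; r < width; ++r)
        for(int c = 0; c < width; ++c)
            assert(swizzled_col(r, swizzled_col(r, c, width), width) == c); // involution
    return 0;
}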
1 : LdsSize; - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerWmma; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 - ? 
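kfold above folds K0 slices into an LDS row until the row spans roughly 128 bytes (one full pass over the 32 four-byte banks); mpair below applies the same heuristic on the M side. The computation, extracted:

// The 128-byte folding heuristic used for kfold and mpair above.
constexpr int fold_factor(int k1, int m0, int elem_bytes)
{
    const int row_bytes = k1 * m0 * elem_bytes;
    return row_bytes > 128 ? 1 : 128 / row_bytes;
}

// e.g. AK1 = 8, M0 = 4, fp16 (2 bytes): 8 * 4 * 2 = 64 bytes -> fold by 2
static_assert(fold_factor(8, 4, 2) == 2);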
M0 - : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{} * BK1Number, BK1Number, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; - constexpr index_t NLdsLayer = LdsSize < 1 ? 
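Both block descriptors size their swizzled layout with LdsSize = 32 * 4 / KPerBlock / sizeof(T) / PackedSize, that is, how many KPerBlock stripes of the element type fit into one 128-byte bank row. Extracted as a standalone helper:

// Sketch of the [MN]LdsLayer computation: 32 banks * 4 bytes = 128 bytes
// per LDS row; count whole KPerBlock stripes per row, at least one.
constexpr int lds_layer(int k_per_block, int elem_bytes, int packed_size)
{
    const int layers = 32 * 4 / k_per_block / elem_bytes / packed_size;
    return layers < 1 ? 1 : layers;
}

// e.g. KPerBlock = 32, fp16 (2 bytes), unpacked: 128 / (32 * 2) = 2 layers
static_assert(lds_layer(32, 2, 1) == 2);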
1 : LdsSize; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerWmma; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 - ? 
N0 - : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - - __host__ __device__ static constexpr auto - // *Caution Here repeat is shuffle repeat - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() - { - constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; - } - - using BlockwiseGemmPipe = remove_cvref_t< - decltype(BlockGemmPipeline_Selector< - BlkGemmPipelineVer, - BlkGemmPipeSched, - BlockSize, - ADataType, - BDataType, - ComputeTypeA, - ComputeTypeB, - AccDataType, - decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), - decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - KPack>())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - 
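GetSharedMemoryNumberOfByte, which begins above, budgets LDS as max(A tile + B tile, C shuffle tile), because the epilogue reuses the same allocation once the main loop has drained. The shape of that computation:

// Sketch of the LDS budget: A and B coexist during the main loop; the
// C shuffle tile replaces them afterwards, so only the max is allocated.
#include <algorithm>
#include <cstddef>

std::size_t shared_mem_bytes(std::size_t a_elems_aligned, std::size_t a_bytes, int a_pack,
                             std::size_t b_elems_aligned, std::size_t b_bytes, int b_pack,
                             std::size_t c_elems, std::size_t c_bytes)
{
    const std::size_t ab = a_elems_aligned * a_bytes / a_pack
                         + b_elems_aligned * b_bytes / b_pack;
    return std::max(ab, c_elems * c_bytes);
}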
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + - b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), - c_block_size * sizeof(CShuffleDataType)); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ static constexpr bool CheckValidity(const Argument& karg) - { - static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && - (NPerBlock % (NPerWmma * NRepeat)) == 0, - "Invalid tuning param!"); - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && - !(is_same::value)) - { - if(!(karg.M % MPerBlock == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && - (is_same::value)) - { - if(!(karg.N % NPerBlock == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) - { - - auto K_t = karg.KBatch * KPerBlock; - if(!(karg.K % K_t == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " - << karg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); - auto K_t = karg.KBatch * KReadVec; - auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; - if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) - { - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.K % ABlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - if(karg.M % ABlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.N % BBlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - if(karg.K % BBlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - else - { - if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(is_same, half_t>::value || - is_same, float>::value || - is_same, bhalf_t>::value || - is_same, int32_t>::value)) - { - if(!karg.IsReduceAdd()) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - if(karg.KBatch > 1) - { - return false; - } - } - } - - // check gridwise gemm pipeline - const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); - - if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) - { - if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) - { - return false; - } - } - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) - { - const index_t num_loop = K / KPerBlock; - - return BlockwiseGemmPipe::BlockHasHotloop(num_loop); - } - - __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) - { - const index_t num_loop = K / KPerBlock; - - return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); - } - - template - __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) - { - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), - make_unmerge_transform(make_tuple(NBlock, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - return c_grid_desc_mblock_mperblock_nblock_nperblock; - } + using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; // return block_id to C matrix tile idx (m0, n0) mapping // if arch = gfx942 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - void* p_shared, - const Problem& problem, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; - const CElementwiseOperation c_element_op{}; - - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) - { - return; - } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = 
__builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // Cast after lds - auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - - auto b_block_buf = make_dynamic_buffer( - reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * - sizeof(ADataType) / - APackedSize), - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); - - // Blockwise GEMM pipeline - static_assert(std::is_default_constructible_v); - auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; - auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); - - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, 
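The buffer setup above carves a single dynamic LDS block: A's tile sits at offset zero and B's tile starts after A's aligned footprint, with the division by APackedSize accounting for pk_i4 storing two logical values per byte. The offset computation, isolated:

// Sketch of the LDS carving: B's buffer begins where A's aligned region ends.
#include <cstddef>

char* b_lds_base(char* p_shared,
                 std::size_t a_elems_aligned, // element-space size, LDS-aligned
                 std::size_t a_elem_bytes,
                 int a_packed_size)
{
    return p_shared + a_elems_aligned * a_elem_bytes / a_packed_size;
}
// e.g. a 64 x 128 fp16 A tile puts B 16 KiB into the shared block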
- b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); - - // shuffle C and write out - { - // C mapping in single thread. - constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - blockwise_gemm_pipeline - .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - // C mapping in single block - constexpr auto - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = - blockwise_gemm_pipeline - .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - constexpr auto MWave = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I1); - constexpr auto MSubGroup = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I2); - constexpr auto NWave = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I4); - constexpr auto NThreadPerSubGroup = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I5); - constexpr auto MAccVgprs = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I6); - - // LDS descriptor, shuffle and write out in MRepeat x NRepeat times - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize()); - - constexpr auto - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - transform_tensor_descriptor( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // MRepeat per shuffle repeat - MWave, // MWave - MSubGroup, // MSubGroup * MAccVgprs = MPerWmma - MAccVgprs)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // NRepeat per shuffle repeat - NWave, // NWave - NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 1, 2, 6>{}, - Sequence<>{}, - Sequence<3, 4, 5>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( - MRepeat, MWave, MSubGroup, MAccVgprs))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor - .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( - NRepeat, NWave, NThreadPerSubGroup))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto 
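The single-stage adaptors above invert make_merge_transform: CalculateBottomIndex peels a flat per-block index into (MRepeat, MWave, MSubGroup, MAccVgprs) digits. A host-side sketch of that mixed-radix decode:

// Mixed-radix decode, most significant digit first, as the merge adaptor
// inversion above performs it for the thread's m coordinate.
#include <array>
#include <cassert>

std::array<int, 4> decode_m(int m, int mrepeat, int mwave, int msubgroup, int maccvgprs)
{
    std::array<int, 4> idx{};
    const std::array<int, 4> lens{mrepeat, mwave, msubgroup, maccvgprs};
    for(int d = 3; d >= 0; --d) // peel the least significant digit first
    {
        idx[d] = m % lens[d];
        m /= lens[d];
    }
    return idx;
}

int main()
{
    const auto idx = decode_m(/*m=*/13, 2, 2, 2, 4); // digit strides: 16, 8, 4, 1
    assert(idx[0] == 0 && idx[1] == 1 && idx[2] == 1 && idx[3] == 1); // 8 + 4 + 1 = 13
    return 0;
}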
n_thread_data_on_block_idx = - n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor - .CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6>, - 6, - 1, // vector write pixel - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - make_multi_index(0, - m_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - 0, - n_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerWmma, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - c_element_op}; - - // space filling curve for local reg & global memory - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerWmma, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - 
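The static_for loop here alternates block_sync_lds, a VGPR-to-LDS thread copy, another sync, and an LDS-to-global block copy; the static_assert above pins both space-filling curves to the same access count. That count, extracted:

// The C tile drains in (MRepeat / CShuffleMRepeatPerShuffle) *
// (NRepeat / CShuffleNRepeatPerShuffle) LDS round trips.
constexpr int num_shuffle_accesses(int mrepeat, int m_per_shuffle,
                                   int nrepeat, int n_per_shuffle)
{
    return (mrepeat / m_per_shuffle) * (nrepeat / n_per_shuffle);
}

// e.g. MRepeat = NRepeat = 4 with one repeat shuffled per pass -> 16 rounds
static_assert(num_shuffle_accesses(4, 1, 4, 1) == 16);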
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } - } + __device__ static index_t GetKBlockPerScale() { return 1; } template (p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - problem, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // BScale struct (Empty) + using BScale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_m_id, + block_n_id, + num_k_block_per_scale, + b_scale_struct); + } + + // Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. + template + __device__ static void + Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg) + { + Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp new file mode 100644 index 0000000000..37ffbf1c51 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -0,0 +1,541 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
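GetKBlockPerScale and the wrapper Run above exist so that one __global__ entry point (now in the common header) can serve every gridwise variant; each variant only has to expose Run(p_shared, splitk_batch_offset, karg). A toy host-side analogue of the dispatch:

// Toy analogue of the wrapper-Run design: one templated entry point,
// variant-specific work behind a uniform static Run signature.
#include <iostream>

struct Universal { static void Run(int arg) { std::cout << "universal " << arg << "\n"; } };
struct BScale    { static void Run(int arg) { std::cout << "b_scale "   << arg << "\n"; } };

template <typename GridwiseGemm>
void kernel_entry(int arg) // stands in for the shared __global__ function
{
    GridwiseGemm::Run(arg);
}

int main()
{
    kernel_entry<Universal>(1);
    kernel_entry<BScale>(2);
    return 0;
}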
+ +#pragma once + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { + +template +struct GridwiseGemm_wmma_cshuffle_v3_b_scale + : GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB> +{ + using BScaleType = ck::half_t; + + using Base = GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + + using 
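The new b_scale variant attaches one scale (BScaleType, half here) to each ScaleBlockN x ScaleBlockK tile of B. Reference semantics as a scalar sketch, with illustrative names and float throughout; the real pipeline applies the scale inside the block GEMM rather than materializing a dequantized B:

// Reference semantics of block-wise B scaling (scalar sketch).
#include <vector>

std::vector<float> dequantize_b(const std::vector<float>& b,     // N x K, row-major
                                const std::vector<float>& scale, // ceil(N/SBN) x ceil(K/SBK)
                                int N, int K, int SBN, int SBK)
{
    const int scale_cols = (K + SBK - 1) / SBK;
    std::vector<float> out(b.size());
    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
            out[n * K + k] = b[n * K + k] * scale[(n / SBN) * scale_cols + k / SBK];
    return out;
}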
Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + + using Base::APackedSize; + using Base::BPackedSize; + + using Base::CalculateAK0Padded; + using Base::CalculateBK0Padded; + using Base::CalculateKPadded; + using Base::CalculateKRead; + using Base::CalculateMBlock; + using Base::CalculateMPadded; + using Base::CalculateNBlock; + using Base::CalculateNPadded; + using Base::MakeAGridDescriptor_AK0_M_AK1; + using Base::MakeBGridDescriptor_BK0_N_BK1; + using Base::MakeCGridDescriptor_M_N; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + + using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; + + using ThisThreadBlock = ThisThreadBlock; + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + StrideScaleB{StrideScaleB_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded + << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t StrideScaleB; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + + const BScaleType* p_b_scale_grid; + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const 
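For KBatch > 1 the partial products must be combined, and the two predicates below are mutually exclusive: is_reduce selects a separate reduction pass over a KBatch-deep workspace (hence c_reduce_offset = z * M * N in SplitKBatchOffset), otherwise the kernel atomically adds into C. A sketch of the resulting dispatch, not the device code path:

// The split-K combine choice encoded by IsReduceAdd / IsAtomicAdd.
enum class SplitKCombine { None, ReducePass, AtomicAdd };

constexpr SplitKCombine combine_mode(int k_batch, bool is_reduce)
{
    if(k_batch <= 1) return SplitKCombine::None;
    return is_reduce ? SplitKCombine::ReducePass : SplitKCombine::AtomicAdd;
}

static_assert(combine_mode(4, false) == SplitKCombine::AtomicAdd);
static_assert(combine_mode(4, true)  == SplitKCombine::ReducePass);
static_assert(combine_mode(1, true)  == SplitKCombine::None);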
CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB; + } + else if constexpr(is_same_v) + { + scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const BScaleType* p_b_scale_grid, + index_t block_n_id) + { + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + static constexpr auto wmma = + WmmaSelector{}; + static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma; + + static constexpr auto ScaleSliceSizeN = NRepeat; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma + + (get_thread_local_1d_id() / 32) % NWaves * NPerWmma; + auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + b_thread_offset_k / ScaleBlockK)); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + using BScale = + typename BlockwiseGemmPipe::template BScale; + + return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf}; + } + + __device__ static index_t GetKBlockPerScale() + { + return (ScaleBlockK + KPerBlock - 1) / KPerBlock; + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, 
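SplitKBatchOffset gains scale_k_split_offset, the scale-grid analogue of the A/B offsets: each blockIdx.z slice skips KRead / ScaleBlockK scale entries along K. GetKBlockPerScale then tells the pipeline how many KPerBlock tiles share one scale. Both extracted:

// Scale-related quantities from the b_scale variant (column-major scale case).
constexpr int scale_k_split_offset_colmajor(int z, int k_read, int scale_block_k)
{
    return z * (k_read / scale_block_k); // scales contiguous along K here
}

constexpr int k_block_per_scale(int scale_block_k, int k_per_block)
{
    return (scale_block_k + k_per_block - 1) / k_per_block; // ceil
}

// e.g. ScaleBlockK = 128, KPerBlock = 64 -> one scale spans 2 K tiles
static_assert(k_block_per_scale(128, 64) == 2);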
problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // B Scale grid + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(problem.StrideScaleB, 1)); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // BScale struct + auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id); + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_m_id, + block_n_id, + num_k_block_per_scale, + b_scale_struct); + } + + // NOTE: Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. + template + __device__ static void + Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg) + { + Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp new file mode 100644 index 0000000000..fc01866ddf --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,1420 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
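/* Why the scale_k_split_offset used by the wrapper Run above divides by
   ScaleBlockK: B is block-quantized, so one scale entry covers a
   ScaleBlockN x ScaleBlockK tile and the scale grid only has
   ceil(K / ScaleBlockK) entries along K. Assuming KRead is a multiple of
   ScaleBlockK, each split-K batch advances the scale pointer by

       blockIdx.z * (KRead / ScaleBlockK) * StrideB   // strided scale layout
       blockIdx.z * (KRead / ScaleBlockK)             // contiguous scale layout

   matching the two layout branches of SplitKBatchOffset.
*/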
+ +#pragma once + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_wmma_cshuffle_v3_base +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + WmmaSelector::selected_wmma + .k_per_wmma); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / 
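/* All of the Calculate*Padded helpers in this struct reduce to two integer
   identities, written here with hypothetical names:

       ceil_div(x, m)  = (x + m - 1) / m
       round_up(x, m)  = ceil_div(x, m) * m

   e.g. CalculateAK0Padded(K, KBatch) = ceil_div(K, KBatch * KPerBlock) *
   (KPerBlock / AK1Value): the per-batch count of AK1-wide K0 slots after
   padding.
*/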
AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) + { + // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
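/* Shape check for the (AK0, M, AK1) view these branches build: the unmerge
   splits K so that AK1 is the contiguous, vector-loadable axis. For a
   hypothetical K = 64 with AK1Value = 8 (so AK0 = 8), element (m, k) of the
   padded A matrix maps to

       ak0 = k / AK1Value;   // slow K axis, dim 0
       ak1 = k % AK1Value;   // fast K axis, dim 2
       // index into a_grid_desc_ak0_m_ak1 is (ak0, m, ak1)
*/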
make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + static_assert(!PermuteA, "PermuteA is not supported"); + + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + 
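/* Context for the pk_i4_t static_assert above (an inference from the
   APackedSize / BPackedSize definitions): pk_i4_t stores two 4-bit values per
   byte, so element counts halve when converted to storage, roughly

       storage_bytes = K * N / BPackedSize;   // BPackedSize == 2 for pk_i4_t

   Right-padding K or N could land a tile boundary in the middle of a packed
   pair, which is presumably why only GemmSpecialization::Default is allowed.
*/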
make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + + return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // TODO: Investigate why this path is not used in the original + // gridwise_gemm_xdl_cshuffle_v3.hpp +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + __device__ 
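/* The PermuteB branch above reads a pre-shuffled weight tensor
   BGlobal[K/KPerBlock, N, KPerBlock/BK1, BK1] through a (BK0, N, BK1) view.
   A sketch of the flat offset it produces for a logical (k, n), using the
   local names from that branch (BK01 = KPerBlock / BK1Value):

       bk0    = k / BK1Value;
       bk1    = k % BK1Value;
       bk00   = bk0 / BK01;
       bk01   = bk0 % BK01;
       offset = ((bk00 * N + n) * BK01 + bk01) * BK1Value + bk1;
*/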
+    static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
+    {
+        // A matrix in LDS memory, dst of blockwise copy
+        if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
+        {
+            // Writing the data into LDS in this layout causes bank conflicts, but
+            // pipeline v4 has a whole loop iteration to hide them, and the layout may
+            // save a few VALU instructions in address computation.
+            return make_naive_tensor_descriptor(
+                make_tuple(AK0Number, Number{}, AK1Number),
+                make_tuple(Number{} * AK1Number, AK1Number, I1));
+        }
+        // The xor tensor transformation requires extra VGPRs and can cause register
+        // spills in some cases.
+        else if constexpr(is_same::value)
+        {
+            constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
+            constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
+            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
+                make_tuple(
+                    AK0Number * Number{}, Number{}, AK1Number),
+                make_tuple(AK1Number, Number{}, I1));
+
+            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
+                a_lds_block_desc,
+                make_tuple(make_xor_with_modulo_transform(make_tuple(
+                               Number{}, Number{})),
+                           make_pass_through_transform(AK1Number)),
+                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
+                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
+
+            constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
+                a_lds_block_desc_permuted,
+                make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})),
+                           make_pass_through_transform(Number{}),
+                           make_pass_through_transform(AK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
+
+            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_lds_block_desc_ak0_mldslayer_m_ak1,
+                make_tuple(make_pass_through_transform(AK0Number),
+                           make_merge_transform_v3_division_mod(
+                               make_tuple(Number{}, Number{})),
+                           make_pass_through_transform(AK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+            return a_lds_block_desc_ak0_m_ak1;
+        }
+        else // ColumnMajor A
+        {
+            // The kfold and mpair dimensions are not always required; more dimensions
+            // in merge_transform make it harder for the compiler to generate immarg
+            // offsets.
+            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
+            constexpr auto M1 = MPerBlock / M0;
+
+            constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
+            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
+            constexpr auto KThreadRead = 64 / MPerWmma;
+            constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
+
+            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
+                                       ? 1
+                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
+            constexpr auto KThreadReadPerm =
+                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
+                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
+                    : KThreadRead;
+
+            // 1 <= mpair <= M0
+            constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128)
+                                       ? 1
+                                       : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0
+                                              ? M0
+                                              : 128 / (AK1Number * MPerWmma * sizeof(ADataType)));
+
+            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
+                make_tuple(Number{},
+                           Number{},
+                           Number{},
+                           Number{},
+                           Number{},
+                           AK1Number));
+
+            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
+                a_lds_block_desc,
+                make_tuple(
+                    make_pass_through_transform(Number{}),
+                    make_pass_through_transform(Number{}),
+                    make_xor_with_modulo_transform(
+                        make_tuple(Number{}, Number{})),
+                    make_pass_through_transform(Number{}),
+                    make_pass_through_transform(AK1Number)),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
+
+            constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
+                a_lds_block_desc_permuted,
+                make_tuple(
+                    make_pass_through_transform(Number{}),
+                    make_pass_through_transform(Number{}),
+                    make_unmerge_transform(make_tuple(Number{}, Number{})),
+                    make_unmerge_transform(make_tuple(Number{}, Number{})),
+                    make_pass_through_transform(Number{}),
+                    make_pass_through_transform(AK1Number)),
+                make_tuple(Sequence<0>{},
+                           Sequence<1>{},
+                           Sequence<2>{},
+                           Sequence<3>{},
+                           Sequence<4>{},
+                           Sequence<5>{}),
+                make_tuple(Sequence<1>{},
+                           Sequence<2>{},
+                           Sequence<0, 3>{},
+                           Sequence<4, 5>{},
+                           Sequence<6>{},
+                           Sequence<7>{}));
+
+            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_lds_block_desc_unmerged,
+                make_tuple(make_merge_transform_v3_division_mod(
+                               make_tuple(Number{},
+                                          Number{},
+                                          Number{},
+                                          Number{})),
+                           make_merge_transform_v3_division_mod(
+                               make_tuple(Number{}, Number{}, Number{})),
+                           make_pass_through_transform(AK1Number)),
+                make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+            return a_lds_block_desc_ak0_m_ak1;
+        }
+    }
+
+    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
+    {
+        // B matrix in LDS memory, dst of blockwise copy
+        if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
+        {
+            // Writing the data into LDS in this layout causes bank conflicts, but
+            // pipeline v4 has a whole loop iteration to hide them, and the layout may
+            // save a few VALU instructions in address computation.
+            return make_naive_tensor_descriptor(
+                make_tuple(BK0Number, Number{}, BK1Number),
+                make_tuple(Number{} * BK1Number, BK1Number, I1));
+        }
+        else if constexpr(is_same::value)
+        {
+            // NLdsLayer * K0 as logical bank
+            constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
+            constexpr index_t NLdsLayer = LdsSize < 1 ?
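/* The xor_with_modulo transforms in these LDS descriptors implement the usual
   bank-conflict swizzle. The general idea, deliberately simplified and not the
   exact transform used here:

       // permute which column a row lands in, so threads that advance by rows
       // spread across LDS banks instead of piling onto one
       swizzled_col = col ^ (row % period);
*/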
1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerWmma; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 + ? 
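/* The kfold / npair selections above and below size one fold to the 128-byte
   granularity these descriptors are built around. Worked example, assuming a
   hypothetical BK1Number = 8, N0 = 4, sizeof(BDataType) = 2 (fp16):

       8 * 4 * 2 = 64 bytes, not > 128  ->  kfold = 128 / 64 = 2

   while rows already wider than 128 bytes force kfold = 1 directly.
*/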
N0 + : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + ComputeTypeB, + AccDataType, + decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), + decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + KPack>())>; + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return 
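/* The (MBlock, MPerBlock, NBlock, NPerBlock) view built by this function is a
   plain tiling of C; for a logical (m, n),

       mblock = m / MPerBlock;   mperblock = m % MPerBlock;
       nblock = n / NPerBlock;   nperblock = n % NPerBlock;

   and one workgroup owns one (mblock, nblock) tile.
*/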
c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NPerWmma * NRepeat)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize(); + + return 
math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t& block_m_id, + const index_t& block_n_id, + const index_t& num_k_block_per_scale, + BScaleStruct& b_scale_struct) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + 
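/* The two ThreadGroupTensorSliceTransfer_v4r1 movers constructed above are
   driven by the pipeline, schematically like this (the real loop lives inside
   BlockwiseGemmPipe::Run, not here):

       for(index_t i = 0; i < num_k_block_main_loop; ++i)
       {
           a_blockwise_copy.Run(...);   // one A tile: global -> LDS
           b_blockwise_copy.Run(...);   // one B tile: global -> LDS
           a_blockwise_copy.MoveSrcSliceWindow(..., a_block_slice_copy_step);
           b_blockwise_copy.MoveSrcSliceWindow(..., b_block_slice_copy_step);
           // ... WMMA math on the previously staged tiles ...
       }
*/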
static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + b_scale_struct, + num_k_block_main_loop, + num_k_block_per_scale); + + // shuffle C and write out + { + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm_pipeline + .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm_pipeline + .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 1, 2, 6>{}, + 
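/* CalculateBottomIndex on the merge adaptors below just peels a flat in-block
   coordinate apart by successive div/mod; on the M side the flattening being
   inverted is

       m = ((mrepeat * MWave + mwave) * MSubGroup + msubgroup) * MAccVgprs
         + maccvgprs

   and analogously on the N side with (NRepeat, NWave, NThreadPerSubGroup).
*/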
Sequence<>{}, + Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor + .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor + .CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + 
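/* The per-access protocol inside the static_for below needs both barriers:

       block_sync_lds();                              // LDS free to overwrite
       c_thread_copy_vgpr_to_lds.Run(...);            // VGPR -> LDS
       block_sync_lds();                              // tile fully written
       c_shuffle_block_copy_lds_to_global.Run(...);   // LDS -> global

   The first sync orders against the previous access's global store; the
   second makes the freshly shuffled tile visible before it is read back out.
*/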
Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 63d40f6ff8..68112489ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -217,20 +217,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index d45ed79ae3..9089bd2ce2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -33,9 +33,9 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) 
|| defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -54,9 +54,9 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // Pass two lds pointer is the key to tell compiler that ds_read/write @@ -538,24 +538,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << ", " - << "Stream-K Selection:" << Streamk_sel << ", " - << "Grid size:" << Grid_size << ", " - << "Reduction Strategy:" + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << ", " << "Stream-K Selection:" << Streamk_sel + << ", " << "Grid size:" << Grid_size << ", " << "Reduction Strategy:" << (reduction_strategy == StreamKReductionStrategy::Atomic ? 
"Atomic" : "Reduction") << "}" << std::endl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 7edcd7270f..c22229a183 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -20,9 +20,9 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -42,12 +42,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - typename GridwiseGemm::Problem problem) + kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + typename GridwiseGemm::Problem problem) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -436,20 +436,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; @@ -822,11 +813,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ComputeTypeB, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index f92268265f..48c577b2e0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -20,7 +20,7 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) @@ -46,12 +46,12 @@ template __global__ void #if 
CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) #endif - kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, - const FloatB* p_b_grid, - FloatC* p_c_grid, - typename GridwiseGemm::Problem problem) + kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + typename GridwiseGemm::Problem problem) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -475,20 +475,11 @@ struct GridwiseGemm_xdl_cshuffle_v2 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; @@ -881,11 +872,11 @@ struct GridwiseGemm_xdl_cshuffle_v2 constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ComputeTypeA, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< // BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 6270d0c4dc..5f3950b29e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -30,7 +30,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -58,7 +58,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) @@ -666,20 +666,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded 
<< ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 8d5c844103..91f08413af 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -30,7 +30,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg) @@ -58,7 +58,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) @@ -155,11 +155,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle static constexpr bool is_single_rate_mfma = true; static constexpr auto is_scale_mfma = false; static constexpr auto mfma = MfmaSelector{}; + MPerXdl, + NPerXdl, + ComputeTypeA, + is_single_rate_mfma, + is_scale_mfma>{}; static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); @@ -575,20 +575,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index 93c1779a80..d8c697823a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -30,7 +30,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -60,7 +60,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename 
GridwiseGemm::Argument karg) @@ -563,22 +563,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "SScaleB:" << StrideScaleB << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded + << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index 97d0e2a4eb..9f442906f5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -29,7 +29,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -59,7 +59,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) @@ -589,18 +589,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead + << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 + << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" + << std::endl; } index_t M; @@ -1757,18 +1750,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - 
generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -2340,18 +2331,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index c8dbd81b73..17b4cd7c68 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -33,7 +33,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg) @@ -65,7 +65,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg) @@ -143,7 +143,8 @@ template + typename LDSTypeB = BDataType, + bool DoElementwiseBeforeCShuffle = false> struct GridwiseGemmMultiD_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; @@ -466,6 +467,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 { return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); } + else + { + static_assert(false, + "The layout configuration is not supported! 
" + "Only support Row & Col major."); + } }(); // pad M and N @@ -538,8 +545,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 Number{}); } - using DsGridDesc_M_N = remove_cvref_t; - struct Problem { __host__ __device__ Problem() = default; @@ -572,20 +577,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1245,11 +1241,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1273,11 +1269,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, TailNumber TailNum = TailNumber::Odd> - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1288,17 +1284,62 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); + } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation 
b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const CGridDesc_M_N& c_grid_desc_m_n) + { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -1515,43 +1556,63 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, problem.MBlock, problem.NBlock); @@ -1566,18 +1627,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> 
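/* generate_tie calls the lambda for i = 0 .. NumDTensor-1 and ties the
   returned references into a tuple; concat_tuple_of_reference then prepends
   the C-shuffle descriptor (or buffer), producing a single tuple that the
   blockwise copy can index uniformly over C and all D tensors. The
   trailing-return "-> const auto&" is what keeps the elements references
   rather than copies, as the inline comment notes. */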
const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -1601,7 +1660,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 Tuple, decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, + conditional_t, Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence // support arbitray type Sequence<1, @@ -1625,7 +1686,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; + lds_to_global_element_op()}; // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = @@ -1698,12 +1759,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 template - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared_0, + void* __restrict__ p_shared_1, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1729,12 +1790,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, TailNumber TailNum = TailNumber::Odd> - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared_0, + void* __restrict__ p_shared_1, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1745,8 +1806,53 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + Run_2Lds(p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); + } + + template + __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared_0, + void* __restrict__ p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& 
b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const CGridDesc_M_N& c_grid_desc_m_n) + { const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -1982,43 +2088,63 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, problem.MBlock, problem.NBlock); @@ -2033,18 +2159,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -2068,7 +2192,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 Tuple, decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, + conditional_t, Sequence(EGlobalMemoryDataOperation)>, 
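/* DoElementwiseBeforeCShuffle picks where the C elementwise op is applied:
   when true, the VGPR-to-LDS shuffle copy runs it (vpgr_to_lds_element_op)
   and the LDS-to-global copy gets a PassThrough; when false, the roles
   swap. Each helper lambda uses `if constexpr`, so return-type deduction
   sees only the surviving branch. A minimal sketch of the idiom, names
   illustrative:

       template <bool Early, class Op, class Pass>
       constexpr decltype(auto) pick_op(Op& op, Pass& pass)
       {
           if constexpr(Early)
               return (op);    // apply op early, before the shuffle
           else
               return (pass);  // defer op; hand back a pass-through
       }
*/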
// FIXME: make Sequence // support arbitray type Sequence<1, @@ -2092,7 +2218,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; + lds_to_global_element_op()}; // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 64fbda7a44..b41f1220fb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -33,7 +33,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -538,20 +538,11 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1556,18 +1547,16 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 3553a1d040..27926e5290 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -33,7 +33,7 @@ template __global__ void #if 
CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg) @@ -65,7 +65,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) @@ -174,11 +174,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle : false; static constexpr auto is_scale_mfma = false; static constexpr auto mfma = MfmaSelector{}; + MPerXdl, + NPerXdl, + ComputeTypeA, + is_single_rate_mfma, + is_scale_mfma>{}; static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); static constexpr index_t KGroup = []() { if constexpr(is_same_v, f8_t>) @@ -599,20 +599,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1414,18 +1405,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -1855,18 +1844,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + 
Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index 909376e5f7..20711f0c5e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -33,7 +33,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle( @@ -66,7 +66,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds( @@ -555,20 +555,11 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1446,18 +1437,16 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = 
container_concat( @@ -1948,18 +1937,16 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index ca3902188e..bc87559c43 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -34,7 +34,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -66,7 +66,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -422,8 +422,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static_assert(!((is_same_v, f6x16_pk_t> || is_same_v, bf6x16_pk_t> || is_same_v, f6x32_pk_t> || - is_same_v, bf6x32_pk_t>)&&GemmSpec != - GemmSpecialization::Default), + is_same_v, bf6x32_pk_t>) && + GemmSpec != GemmSpecialization::Default), "Packed F6 types do not support padding"); if constexpr(GemmSpec == GemmSpecialization::NKPadding || @@ -648,23 +648,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << 
NBlock << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 6691c63484..7902a16fb3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -34,7 +34,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -66,7 +66,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -674,23 +674,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 67fb4d651e..80ce6a1bc4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -36,26 +36,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_layernorm_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, // MxN - const FloatC0* __restrict__ p_c0_bias_grid, // 1xN - const FloatC0* __restrict__ p_c0_add_grid, // MxN - const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN - const FloatC0* __restrict__ p_c0_beta_grid, // 1xN - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const AccElementwiseOperation acc_element_op, - const CElementwiseOperation c_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - 
c_grid_desc_mblock_mperblock_nblock_nperblock, - const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_layernorm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, // MxN + const FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ p_c0_beta_grid, // 1xN + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index b7947309e4..697d0f90d9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -152,19 +152,19 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -182,16 +182,16 @@ __global__ void c_element_op, c_block_cluster_adaptor); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = c_block_cluster_adaptor; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = 
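/* These `ignore = ...;` statements live in the #else branch of the
   architecture guard: when the device pass targets a GPU without the
   required MFMA support (the guard admits host compilation plus
   gfx908/gfx90a/gfx94x), the kernel body is compiled out and assigning each
   parameter to ck::ignore keeps unused-parameter warnings quiet. Simplified
   shape of the guard:

       #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__))
           RunKernel(arg);
       #else
           ignore = arg;   // dead branch: silence -Wunused-parameter
       #endif
*/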
c_block_cluster_adaptor; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -752,11 +752,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + MPerXDL, + NPerXDL, + FloatBAdjusted, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_skip_b_lds_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_skip_b_lds_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index 3e23008a5f..0c5f8de1e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -30,13 +30,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, - const Block2CTileMap& b2c_map, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) + kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); @@ -168,17 +168,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load void Print() const { - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "K0Padded:" << K0Padded << ", " + std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" 
<< StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "K0Padded:" << K0Padded << ", " << "KB:" << k_batch << "}" << std::endl; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp index e9190dee29..104632d3f0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -23,19 +23,19 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid, - const typename GridwiseGemm::FloatAB* p_b_grid, - typename GridwiseGemm::FloatC* p_c_grid, - void* p_workspace, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - typename GridwiseGemm::Block2CTileMap block_mapping) + kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid, + const typename GridwiseGemm::FloatAB* p_b_grid, + typename GridwiseGemm::FloatC* p_c_grid, + void* p_workspace, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + typename GridwiseGemm::Block2CTileMap block_mapping) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -174,13 +174,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk void Print() const { - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << std::endl; + std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << std::endl; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 5c3d9b7ba4..dc9429ea6e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -26,17 +26,17 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif #if CK_USE_WAVES_PER_EU - __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) #endif - kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M_N c_grid_desc_m_n) + kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M_N c_grid_desc_m_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -50,24 +50,24 @@ __global__ void b_grid_desc_k0_n_k1, c_grid_desc_m_n); #else - ignore = p_a_grid; 
- ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; - ignore = c_grid_desc_m_n; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m_n; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif #if CK_USE_WAVES_PER_EU - __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) #endif - kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) + kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -90,7 +90,7 @@ __global__ void b_grid_desc_k0_n_k1, c_grid_desc_m_n); #else - ignore = karg; + ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -200,16 +200,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "K0:" << K0 << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "K0:" << K0 + << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index 7d8e94c001..978f08ad4a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -29,18 +29,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, - const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, - const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 256b495c6e..a546b471bf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -28,13 +28,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, - const Block2CTileMap& b2c_map, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) + kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -175,17 +175,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 void Print() const { - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "K0Padded:" << K0Padded << ", " + std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "K0Padded:" << K0Padded << ", " << "KB:" << k_batch << "}" << std::endl; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 15c2da9d32..66a3fef4eb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -31,20 +31,20 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v3r1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_v3r1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || 
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
index 15c2da9d32..66a3fef4eb 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
@@ -31,20 +31,20 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_xdlops_v3r1(
-        const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
-        FloatC* __restrict__ p_c_grid,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CElementwiseOperation c_element_op,
-        const Block2CTileMap block_2_ctile_map)
+    kernel_gemm_xdlops_v3r1(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CElementwiseOperation c_element_op,
+        const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
index e22bfb6439..eb4e7d3db3 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
@@ -31,23 +31,23 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_xdlops_v3r2(
-        const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
-        FloatC* __restrict__ p_c_grid,
-        const FloatC* __restrict__ p_c0_grid,
-        const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-        const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
-        const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CElementwiseOperation c_element_op,
-        const Block2CTileMap block_2_ctile_map)
+    kernel_gemm_xdlops_v3r2(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        const FloatC* __restrict__ p_c0_grid,
+        const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+        const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+        const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CElementwiseOperation c_element_op,
+        const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
index 3da5e66018..5bd5f75fa9 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
@@ -32,26 +32,26 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_xdlops_v3r3(
-        const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
-        FloatC* __restrict__ p_c_grid,
-        const FloatC* __restrict__ p_c0_grid,
-        const FloatC* __restrict__ p_c1_grid,
-        const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-        const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
-        const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
-        const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-            c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CElementwiseOperation c_element_op,
-        const Block2CTileMap block_2_ctile_map)
+    kernel_gemm_xdlops_v3r3(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        const FloatC* __restrict__ p_c0_grid,
+        const FloatC* __restrict__ p_c1_grid,
+        const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+        const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+        const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+        const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+            c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CElementwiseOperation c_element_op,
+        const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
index 36f8fd7cc1..ca68fe9f86 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
@@ -40,7 +40,7 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_moe_gemm(typename GridwiseGemm::Argument karg)
@@ -75,7 +75,7 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
@@ -619,22 +619,12 @@ struct GridwiseMoeGemm
     __host__ void Print() const
     {
-        std::cout << "problem {"
-                  << "NumTokens:" << NumTokens << ", "
-                  << "TopK:" << TopK << ", "
-                  << "M:" << M << ", "
-                  << "N:" << N << ", "
-                  << "K:" << K << ", "
-                  << "SA:" << StrideA << ", "
-                  << "SB:" << StrideB << ", "
-                  << "SC:" << StrideC << ", "
-                  << "MP:" << MPadded << ", "
-                  << "NP:" << NPadded << ", "
-                  << "KRead:" << KRead << ", "
-                  << "KP:" << KPadded << ", "
-                  << "AK0:" << AK0 << ", "
-                  << "BK0:" << BK0 << ", "
-                  << "MBlock: " << MBlock << ", "
+        std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
+                  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
+                  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
+                  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
+                  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
+                  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
                   << "NBlock: " << NBlock << "}" << std::endl;
     }
@@ -1112,7 +1102,7 @@ struct GridwiseMoeGemm
     }

     // check gridwise gemm pipeline
-#if 1
+#if 0
     const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
     if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
@@ -1714,18 +1704,16 @@ struct GridwiseMoeGemm
         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_desc_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-            generate_tie(
-                [&](auto i) -> const auto& // return type should be reference
-                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
-                Number{}));
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+                         Number{}));

         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_buf_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_buf),
-            generate_tie(
-                [&](auto i) -> const auto& // return type should be reference
-                { return ds_grid_buf[i]; },
-                Number{}));
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_buf[i]; },
+                         Number{}));

         // tuple of starting index of C/Ds blockwise copy
         const auto idx_c_ds_block_begin =
@@ -1746,40 +1734,40 @@ struct GridwiseMoeGemm
         const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
         constexpr index_t scatter_weight_idx = 3; // hack fix felix
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
-            ThisThreadBlock,
-            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
-            Tuple,
-            decltype(c_ds_desc_refs),
-            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-            CElementwiseOperation,
-            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
-                                                   // support arbitray type
-            Sequence<1,
-                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                     1,
-                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CDEBlockTransferCluster,
-            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
-            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
-            3, // index_t SrcVectorDim,
-            3, // index_t DstVectorDim,
-            CDEShuffleBlockTransferScalarPerVectors,
-            CShuffleBlockTransferScalarPerVector_NPerBlock,
-            sequence_merge_t<
-                Sequence,
-                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
-            IndexType,
-            1, // ScatterDim
-            true, // OutputScatter: false, only use scatter weights
-            scatter_weight_idx // ScatterWeightIdx: ascale
-            >{c_ds_desc_refs,
-              idx_c_ds_block_begin,
-              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
-              c_element_op};
+            ThisThreadBlock,
+            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+            Tuple,
+            decltype(c_ds_desc_refs),
+            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+            CElementwiseOperation,
+            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
+                                                   // support arbitray type
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+            CDEBlockTransferCluster,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
+            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
+            3, // index_t SrcVectorDim,
+            3, // index_t DstVectorDim,
+            CDEShuffleBlockTransferScalarPerVectors,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            sequence_merge_t<
+                Sequence,
+                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
+            IndexType,
+            1, // ScatterDim
+            true, // OutputScatter: false, only use scatter weights
+            scatter_weight_idx // ScatterWeightIdx: ascale
+            >{c_ds_desc_refs,
+              idx_c_ds_block_begin,
+              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
+              c_element_op};

         auto c_grid_buf = make_dynamic_buffer(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
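// [Illustrative sketch, not part of the patch] The generate_tie hunk above
// only reflows a lambda whose point is to build a tuple of *references* to
// the C/Ds descriptors instead of copies. A standard-C++ analogue of that
// idea (the CK helper itself is not reproduced; all names here are invented):
#include <array>
#include <tuple>
#include <type_traits>
#include <utility>

template <typename F, std::size_t... Is>
auto generate_tie_sketch(F&& f, std::index_sequence<Is...>)
{
    // forward_as_tuple preserves the lvalue references returned by f(i)
    return std::forward_as_tuple(f(std::integral_constant<std::size_t, Is>{})...);
}

int main()
{
    std::array<double, 3> bufs{1.0, 2.0, 3.0};
    auto refs = generate_tie_sketch([&](auto i) -> const auto& { return bufs[i]; },
                                    std::make_index_sequence<3>{});
    static_assert(std::is_same_v<decltype(refs),
                                 std::tuple<const double&, const double&, const double&>>);
    return std::get<1>(refs) == 2.0 ? 0 : 1; // the tuple aliases bufs, no copies
}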
@@ -2436,18 +2424,16 @@ struct GridwiseMoeGemm
         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_desc_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-            generate_tie(
-                [&](auto i) -> const auto& // return type should be reference
-                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
-                Number{}));
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+                         Number{}));

         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_buf_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_buf),
-            generate_tie(
-                [&](auto i) -> const auto& // return type should be reference
-                { return ds_grid_buf[i]; },
-                Number{}));
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_buf[i]; },
+                         Number{}));

         // tuple of starting index of C/Ds blockwise copy
         const auto idx_c_ds_block_begin =
@@ -2468,40 +2454,40 @@ struct GridwiseMoeGemm
         const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
         constexpr index_t scatter_weight_idx = 3; // hack fix felix
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
-            ThisThreadBlock,
-            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
-            Tuple,
-            decltype(c_ds_desc_refs),
-            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-            CElementwiseOperation,
-            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
-                                                   // support arbitray type
-            Sequence<1,
-                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                     1,
-                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CDEBlockTransferCluster,
-            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
-            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
-            3, // index_t SrcVectorDim,
-            3, // index_t DstVectorDim,
-            CDEShuffleBlockTransferScalarPerVectors,
-            CShuffleBlockTransferScalarPerVector_NPerBlock,
-            sequence_merge_t<
-                Sequence,
-                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
-            IndexType,
-            1, // ScatterDim
-            true, // OutputScatter: false, only use scatter weights
-            scatter_weight_idx // ScatterWeightIdx: ascale
-            >{c_ds_desc_refs,
-              idx_c_ds_block_begin,
-              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
-              c_element_op};
+            ThisThreadBlock,
+            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+            Tuple,
+            decltype(c_ds_desc_refs),
+            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+            CElementwiseOperation,
+            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
+                                                   // support arbitray type
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+            CDEBlockTransferCluster,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
+            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
+            3, // index_t SrcVectorDim,
+            3, // index_t DstVectorDim,
+            CDEShuffleBlockTransferScalarPerVectors,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            sequence_merge_t<
+                Sequence,
+                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
+            IndexType,
+            1, // ScatterDim
+            true, // OutputScatter: false, only use scatter weights
+            scatter_weight_idx // ScatterWeightIdx: ascale
+            >{c_ds_desc_refs,
+              idx_c_ds_block_begin,
+              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
+              c_element_op};

         auto c_grid_buf = make_dynamic_buffer(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
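// [Illustrative sketch, not part of the patch] The instantiation reflowed
// above configures a scatter store: each row of the computed C tile is written
// to a destination row taken from a side index array (how MoE outputs land in
// token order) rather than to its own row. A scalar sketch of that addressing;
// every name below is invented and none of this is the CK class itself.
#include <cstddef>
#include <cstdint>
#include <vector>

void scatter_rows(const std::vector<float>& c_tile,    // m_per_block x n, row-major
                  const std::vector<int32_t>& row_map, // m_per_block scatter row indices
                  std::vector<float>& c_grid,          // num_tokens x n, row-major
                  int n)
{
    const int m_per_block = static_cast<int>(row_map.size());
    for(int m = 0; m < m_per_block; ++m)
        for(int j = 0; j < n; ++j)
            // row_map[m] plays the role of the "scatter weight" offset
            c_grid[static_cast<std::size_t>(row_map[m]) * n + j] = c_tile[m * n + j];
}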
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
index f092c9c1eb..7145efbd97 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
@@ -40,7 +40,7 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_moe_gemm(typename GridwiseGemm::Argument karg)
@@ -77,7 +77,7 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
@@ -626,22 +626,12 @@ struct GridwiseMoeGemmBlockScale
     __host__ void Print() const
     {
-        std::cout << "problem {"
-                  << "NumTokens:" << NumTokens << ", "
-                  << "TopK:" << TopK << ", "
-                  << "M:" << M << ", "
-                  << "N:" << N << ", "
-                  << "K:" << K << ", "
-                  << "SA:" << StrideA << ", "
-                  << "SB:" << StrideB << ", "
-                  << "SC:" << StrideC << ", "
-                  << "MP:" << MPadded << ", "
-                  << "NP:" << NPadded << ", "
-                  << "KRead:" << KRead << ", "
-                  << "KP:" << KPadded << ", "
-                  << "AK0:" << AK0 << ", "
-                  << "BK0:" << BK0 << ", "
-                  << "MBlock: " << MBlock << ", "
+        std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
+                  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
+                  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
+                  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
+                  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
+                  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
                   << "NBlock: " << NBlock << "}" << std::endl;
     }
@@ -1764,18 +1754,16 @@ struct GridwiseMoeGemmBlockScale
         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_desc_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-            generate_tie(
-                [&](auto i) -> const auto& // return type should be reference
-                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
-                Number{}));
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+                         Number{}));

         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_buf_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_buf),
-            generate_tie(
-                [&](auto i) -> const auto& // return type should be reference
-                { return ds_grid_buf[i]; },
-                Number{}));
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_buf[i]; },
+                         Number{}));
         // tuple of starting index of C/Ds blockwise copy
         const auto idx_c_ds_block_begin =
@@ -1796,40 +1784,40 @@ struct GridwiseMoeGemmBlockScale
         const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
         constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
-            ThisThreadBlock,
-            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
-            Tuple,
-            decltype(c_ds_desc_refs),
-            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-            CElementwiseOperation,
-            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
-                                                   // support arbitray type
-            Sequence<1,
-                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                     1,
-                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CDEBlockTransferCluster,
-            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
-            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
-            3, // index_t SrcVectorDim,
-            3, // index_t DstVectorDim,
-            CDEShuffleBlockTransferScalarPerVectors,
-            CShuffleBlockTransferScalarPerVector_NPerBlock,
-            sequence_merge_t<
-                Sequence,
-                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
-            IndexType,
-            1, // ScatterDim
-            true, // OutputScatter: false, only use scatter weights
-            scatter_weight_idx // ScatterWeightIdx: ascale
-            >{c_ds_desc_refs,
-              idx_c_ds_block_begin,
-              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
-              c_element_op};
+            ThisThreadBlock,
+            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+            Tuple,
+            decltype(c_ds_desc_refs),
+            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+            CElementwiseOperation,
+            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
+                                                   // support arbitray type
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+            CDEBlockTransferCluster,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
+            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
+            3, // index_t SrcVectorDim,
+            3, // index_t DstVectorDim,
+            CDEShuffleBlockTransferScalarPerVectors,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            sequence_merge_t<
+                Sequence,
+                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
+            IndexType,
+            1, // ScatterDim
+            true, // OutputScatter: false, only use scatter weights
+            scatter_weight_idx // ScatterWeightIdx: ascale
+            >{c_ds_desc_refs,
+              idx_c_ds_block_begin,
+              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
+              c_element_op};

         auto c_grid_buf = make_dynamic_buffer(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -2506,18 +2494,16 @@ struct GridwiseMoeGemmBlockScale
         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_desc_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-            generate_tie(
-                [&](auto i) -> const auto& // return type should be reference
-                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
-                Number{}));
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+                         Number{}));

         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_buf_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_buf),
-            generate_tie(
-                [&](auto i) -> const auto& // return type should be reference
-                { return ds_grid_buf[i]; },
-                Number{}));
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_buf[i]; },
+                         Number{}));

         // tuple of starting index of C/Ds blockwise copy
         const auto idx_c_ds_block_begin =
@@ -2538,40 +2524,40 @@ struct GridwiseMoeGemmBlockScale
         const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
         constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
-            ThisThreadBlock,
-            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
-            Tuple,
-            decltype(c_ds_desc_refs),
-            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-            CElementwiseOperation,
-            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
-                                                   // support arbitray type
-            Sequence<1,
-                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                     1,
-                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CDEBlockTransferCluster,
-            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
-            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
-            3, // index_t SrcVectorDim,
-            3, // index_t DstVectorDim,
-            CDEShuffleBlockTransferScalarPerVectors,
-            CShuffleBlockTransferScalarPerVector_NPerBlock,
-            sequence_merge_t<
-                Sequence,
-                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
-            IndexType,
-            1, // ScatterDim
-            true, // OutputScatter: false, only use scatter weights
-            scatter_weight_idx // ScatterWeightIdx: ascale
-            >{c_ds_desc_refs,
-              idx_c_ds_block_begin,
-              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
-              c_element_op};
+            ThisThreadBlock,
+            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+            Tuple,
+            decltype(c_ds_desc_refs),
+            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+            CElementwiseOperation,
+            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
+                                                   // support arbitray type
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+            CDEBlockTransferCluster,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
+            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
+            3, // index_t SrcVectorDim,
+            3, // index_t DstVectorDim,
+            CDEShuffleBlockTransferScalarPerVectors,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            sequence_merge_t<
+                Sequence,
+                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
+            IndexType,
+            1, // ScatterDim
+            true, // OutputScatter: false, only use scatter weights
+            scatter_weight_idx // ScatterWeightIdx: ascale
+            >{c_ds_desc_refs,
+              idx_c_ds_block_begin,
+              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
+              c_element_op};

         auto c_grid_buf = make_dynamic_buffer(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp
index 5f8e524fb2..6731a7dda6 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp
@@ -81,7 +81,7 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
@@ -678,25 +678,14 @@ struct GridwiseMoeGemmMX
     __host__ void Print() const
     {
-        std::cout << "problem {"
-                  << "NumTokens:" << NumTokens << ", "
-                  << "TopK:" << TopK << ", "
-                  << "M:" << M << ", "
-                  << "N:" << N << ", "
-                  << "K:" << K << ", "
-                  << "SA:" << StrideA << ", "
-                  << "SScaleA:" << StrideScaleA << ", "
-                  << "SB:" << StrideB << ", "
-                  << "SScaleB:" << StrideScaleB << ", "
-                  << "SC:" << StrideC << ", "
-                  << "MP:" << MPadded << ", "
-                  << "NP:" << NPadded << ", "
-                  << "KRead:" << KRead << ", "
-                  << "KP:" << KPadded << ", "
-                  << "AK0:" << AK0 << ", "
-                  << "BK0:" << BK0 << ", "
-                  << "MBlock: " << MBlock << ", "
-                  << "NBlock: " << NBlock << "}" << std::endl;
+        std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
+                  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
+                  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
+                  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
+                  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
+                  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
+                  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
+                  << ", " << "NBlock: " << NBlock << "}" << std::endl;
     }

     index_t NumTokens;
@@ -2769,18 +2758,16 @@ struct GridwiseMoeGemmMX
         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_desc_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-            generate_tie(
-                [&](auto i) -> const auto& // return type should be reference
-                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
-                Number{}));
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+                         Number{}));

         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_buf_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_buf),
-            generate_tie(
-                [&](auto i) -> const auto& // return type should be reference
-                { return ds_grid_buf[i]; },
-                Number{}));
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_buf[i]; },
+                         Number{}));

         // tuple of starting index of C/Ds blockwise copy
         const auto idx_c_ds_block_begin =
@@ -2801,41 +2788,41 @@ struct GridwiseMoeGemmMX
         const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
         constexpr index_t scatter_weight_idx = 3; // hack fix felix
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
-            ThisThreadBlock,
-            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
-            Tuple,
-            decltype(c_ds_desc_refs),
-            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-            CElementwiseOperation,
-            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make
-                                                   // Sequence support
-                                                   // arbitray type
-            Sequence<1,
-                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                     1,
-                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CDEBlockTransferCluster,
-            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
-            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
-            3, // index_t SrcVectorDim,
-            3, // index_t DstVectorDim,
-            CDEShuffleBlockTransferScalarPerVectors,
-            CShuffleBlockTransferScalarPerVector_NPerBlock,
-            sequence_merge_t<
-                Sequence,
-                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
-            IndexType,
-            1, // ScatterDim
-            true, // OutputScatter: false, only use scatter weights
-            scatter_weight_idx // ScatterWeightIdx: ascale
-            >{c_ds_desc_refs,
-              idx_c_ds_block_begin,
-              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
-              c_element_op};
+            ThisThreadBlock,
+            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+            Tuple,
+            decltype(c_ds_desc_refs),
+            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+            CElementwiseOperation,
+            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make
+                                                   // Sequence support
+                                                   // arbitray type
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+            CDEBlockTransferCluster,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
+            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
+            3, // index_t SrcVectorDim,
+            3, // index_t DstVectorDim,
+            CDEShuffleBlockTransferScalarPerVectors,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            sequence_merge_t<
+                Sequence,
+                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
+            IndexType,
+            1, // ScatterDim
+            true, // OutputScatter: false, only use scatter weights
+            scatter_weight_idx // ScatterWeightIdx: ascale
+            >{c_ds_desc_refs,
+              idx_c_ds_block_begin,
+              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
+              c_element_op};

         auto c_grid_buf = make_dynamic_buffer(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp
index 9ccd334262..d8d77ae388 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp
@@ -42,7 +42,7 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
@@ -205,11 +205,11 @@ struct GridwiseMoeGemmMXBNS
     static constexpr bool is_single_rate_mfma = false;
     static constexpr auto is_scale_mfma       = true;
     using mfma_selector = MfmaSelector;
+                                       MPerXdl,
+                                       NPerXdl,
+                                       ComputeTypeB,
+                                       is_single_rate_mfma,
+                                       is_scale_mfma>;
     static constexpr index_t KPack = math::max(
         math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize);
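// [Illustrative sketch, not part of the patch] The KPack expression kept in
// the hunk above takes the larger of (a) the lcm of the per-load K vector
// widths on the A and B side and (b) the MFMA's K-per-block divided by the
// packing factor. A constexpr check with invented example values; nothing here
// is taken from a particular CK instance.
#include <algorithm>
#include <numeric>

constexpr int AK1          = 8;  // per-load K width for A
constexpr int BK1          = 16; // per-load K width for B
constexpr int k_per_blk    = 32; // K consumed by one MFMA instruction
constexpr int APackedSize  = 2;  // e.g. two sub-byte values per stored element
constexpr int KPack = std::max(std::lcm(AK1, BK1), k_per_blk / APackedSize);
static_assert(KPack == 16); // lcm(8,16) = 16 and 32/2 = 16, so both bounds agree

int main() { return KPack; }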
TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t NumTokens; @@ -1956,18 +1945,16 @@ struct GridwiseMoeGemmMXBNS // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -1988,41 +1975,41 @@ struct GridwiseMoeGemmMXBNS const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - 
@@ -1988,41 +1975,41 @@ struct GridwiseMoeGemmMXBNS
         const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
         constexpr index_t scatter_weight_idx = 3; // hack fix felix
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
-            ThisThreadBlock,
-            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
-            Tuple,
-            decltype(c_ds_desc_refs),
-            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-            CElementwiseOperation,
-            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make
-                                                   // Sequence support
-                                                   // arbitray type
-            Sequence<1,
-                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                     1,
-                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CDEBlockTransferCluster,
-            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
-            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
-            3, // index_t SrcVectorDim,
-            3, // index_t DstVectorDim,
-            CDEShuffleBlockTransferScalarPerVectors,
-            CShuffleBlockTransferScalarPerVector_NPerBlock,
-            sequence_merge_t<
-                Sequence,
-                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
-            IndexType,
-            1, // ScatterDim
-            true, // OutputScatter: false, only use scatter weights
-            scatter_weight_idx // ScatterWeightIdx: ascale
-            >{c_ds_desc_refs,
-              idx_c_ds_block_begin,
-              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
-              c_element_op};
+            ThisThreadBlock,
+            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+            Tuple,
+            decltype(c_ds_desc_refs),
+            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+            CElementwiseOperation,
+            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make
+                                                   // Sequence support
+                                                   // arbitray type
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+            CDEBlockTransferCluster,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
+            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
+            3, // index_t SrcVectorDim,
+            3, // index_t DstVectorDim,
+            CDEShuffleBlockTransferScalarPerVectors,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            sequence_merge_t<
+                Sequence,
+                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
+            IndexType,
+            1, // ScatterDim
+            true, // OutputScatter: false, only use scatter weights
+            scatter_weight_idx // ScatterWeightIdx: ascale
+            >{c_ds_desc_refs,
+              idx_c_ds_block_begin,
+              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
+              c_element_op};

         auto c_grid_buf = make_dynamic_buffer(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp
index 156db6e636..7c3dbceeaa 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp
@@ -42,7 +42,7 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
@@ -79,7 +79,7 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
@@ -467,7 +467,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
     __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
     {
-        constexpr index_t NkSwizzleNumber = Number{};
+        constexpr index_t NkSwizzleNumber = Number{};
         return make_naive_tensor_descriptor_packed(
             make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
     }
<< "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t NumTokens; @@ -1474,7 +1463,7 @@ struct GridwiseMoeGemmMX_BPreshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1567,7 +1556,7 @@ struct GridwiseMoeGemmMX_BPreshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride, @@ -2185,7 +2174,7 @@ struct GridwiseMoeGemmMX_BPreshuffle get_warp_local_1d_id() % NWave, 0, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2289,7 +2278,7 @@ struct GridwiseMoeGemmMX_BPreshuffle get_warp_local_1d_id() % NWave, 0, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack * (get_thread_local_1d_id() % WarpSize))); const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType); const auto b_scale_grid_buf_up = make_dynamic_buffer( @@ -2588,18 +2577,16 @@ struct GridwiseMoeGemmMX_BPreshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2620,41 +2607,41 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence 
@@ -2620,41 +2607,41 @@ struct GridwiseMoeGemmMX_BPreshuffle
         const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
         constexpr index_t scatter_weight_idx = 3; // hack fix felix
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
-            ThisThreadBlock,
-            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
-            Tuple,
-            decltype(c_ds_desc_refs),
-            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-            CElementwiseOperation,
-            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make
-                                                   // Sequence support
-                                                   // arbitray type
-            Sequence<1,
-                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                     1,
-                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CDEBlockTransferCluster,
-            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
-            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
-            3, // index_t SrcVectorDim,
-            3, // index_t DstVectorDim,
-            CDEShuffleBlockTransferScalarPerVectors,
-            CShuffleBlockTransferScalarPerVector_NPerBlock,
-            sequence_merge_t<
-                Sequence,
-                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
-            IndexType,
-            1, // ScatterDim
-            true, // OutputScatter: false, only use scatter weights
-            scatter_weight_idx // ScatterWeightIdx: ascale
-            >{c_ds_desc_refs,
-              idx_c_ds_block_begin,
-              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
-              c_element_op};
+            ThisThreadBlock,
+            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+            Tuple,
+            decltype(c_ds_desc_refs),
+            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+            CElementwiseOperation,
+            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make
+                                                   // Sequence support
+                                                   // arbitray type
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+            CDEBlockTransferCluster,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
+            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
+            3, // index_t SrcVectorDim,
+            3, // index_t DstVectorDim,
+            CDEShuffleBlockTransferScalarPerVectors,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            sequence_merge_t<
+                Sequence,
+                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+            Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags
+            IndexType,
+            1, // ScatterDim
+            true, // OutputScatter: false, only use scatter weights
+            scatter_weight_idx // ScatterWeightIdx: ascale
+            >{c_ds_desc_refs,
+              idx_c_ds_block_begin,
+              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
+              c_element_op};

         auto c_grid_buf = make_dynamic_buffer(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
index 61d0f9e0d5..fa9b5fb2ce 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
@@ -86,7 +86,7 @@ struct GridwisePermute
         ~Block2TileMap() = default;

         Block2TileMap& operator=(const Block2TileMap&) = delete;
-        Block2TileMap& operator=(Block2TileMap&&) = delete;
+        Block2TileMap& operator=(Block2TileMap&&)      = delete;

         explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {}
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
index ddf0b4a58d..bffc3c696c 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
@@ -25,15 +25,15 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
-                            const InputDataType* __restrict__ p_in_global,
-                            const OutputGridDesc out_grid_desc,
-                            OutputDataType* __restrict__ p_out_global,
-                            const index_t batch_count,
-                            const Block2ETileMap block_2_tile_map,
-                            const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
+    kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
+                            const InputDataType* __restrict__ p_in_global,
+                            const OutputGridDesc out_grid_desc,
+                            OutputDataType* __restrict__ p_out_global,
+                            const index_t batch_count,
+                            const Block2ETileMap block_2_tile_map,
+                            const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
     defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
index 8a0e16d7f6..e399499cc8 100644
--- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
+++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
@@ -399,7 +399,7 @@ struct GridwiseNormalizationBwdData_mk_to_mk
                                           dx_grid_desc_m_k,
                                           dx_global_val_buf);

-        } // end of sweep once
+        }    // end of sweep once
         else // Sweep Twice pipeline
         {
             constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
index 4e4c92de40..2305997f70 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -823,8 +823,7 @@ struct ThreadwiseTensorSliceTransfer_v3
                 buffer_(Number{}) = src_tmp_vector.template AsType()[i];
             });

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -837,8 +836,7 @@ struct ThreadwiseTensorSliceTransfer_v3
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -983,8 +981,7 @@ struct ThreadwiseTensorSliceTransfer_v3
                 is_dst_valid,
                 dst_tmp_vector.template AsType()[Number<0>{}]);

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -997,8 +994,7 @@ struct ThreadwiseTensorSliceTransfer_v3
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move
             static_for<0, nDim, 1>{}([&](auto i) {
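// [Illustrative sketch, not part of the patch] The move_on_dim hunks above
// only re-attach the "()" that immediately invokes a constexpr lambda to its
// closing brace. The pattern itself: compute a small static table once, at
// compile time, inside an expression. Self-contained sketch; nDim and the
// per-dimension predicate are invented.
#include <array>

constexpr int nDim = 3;

constexpr auto move_on_dim = []() constexpr {
    std::array<bool, nDim> flags{};
    for(int i = 0; i < nDim; ++i)
        flags[i] = (i % 2 == 0); // stand-in for the real per-dimension test
    return flags;
}(); // <- invoked immediately; move_on_dim is the resulting array

static_assert(move_on_dim[0] && !move_on_dim[1] && move_on_dim[2]);

int main() { return move_on_dim[2] ? 0 : 1; }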
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
index 79e22018a6..4a6ed62c0e 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
@@ -246,22 +246,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         using dst_elem_op_vec_t = typename vector_type::type;

         using VectorSizeLookupTable = Tuple,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence>;
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence>;
         using VectorOffsetsLookupTable = Tuple,
                                                Sequence,
                                                Sequence,
@@ -308,8 +308,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 .template SetAsType(src_data_idx_seq,
                                     op_r_v.template AsType()[I0]);

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -322,8 +321,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move src coord
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -636,8 +634,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 is_dst_valid,
                 dst_vector_container.template AsType()[I0]);

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -650,8 +647,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move dst coord
             static_for<0, nDim, 1>{}([&](auto i) {
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
index 174b82f870..8af6a2148b 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
@@ -229,8 +229,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                 .template SetAsType(
                     src_data_idx_seq, src_vector_container.template AsType()[I0]);

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -243,8 +242,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move src coord
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -376,8 +374,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                 scale_thread_scratch_.template SetAsType(
                     scale_data_idx_seq, scale_vector_container.template AsType()[I0]);

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -391,8 +388,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move scale coord
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -666,8 +662,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                 is_dst_valid,
                 dst_vector_container.template AsType()[I0]);

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -680,8 +675,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move dst coord
             static_for<0, nDim, 1>{}([&](auto i) {
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp
index 50f1e21beb..8574fd055c 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp
@@ -277,8 +277,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                 .template SetAsType(src_data_idx_seq,
                                     op_r_v.template AsType()[I0]);

-            auto move_on_dim = [&]() constexpr
-            {
+            auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -292,8 +291,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move src coord
             static_for<0, nDim, 1>{}([&](auto i) {
                 if(move_on_dim[i])
@@ -603,8 +601,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                 is_dst_valid,
                 dst_vector_container.template AsType()[I0]);

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -617,8 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move dst coord
             static_for<0, nDim, 1>{}([&](auto i) {
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp
index f0d793456d..9383e3f829 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp
@@ -229,8 +229,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
                     src_data_idx_seq,
                     src_vector_container.template AsType()[I0]);

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -245,8 +244,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move src coord
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -438,8 +436,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
                 is_dst_valid,
                 dst_vector_container.template AsType()[I0]);

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -454,8 +451,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move dst coord
             static_for<0, nDim, 1>{}([&](auto i) {
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
index 40ebdeff08..4e9c188115 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
@@ -198,8 +198,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
                     src_vector.template AsType()[Number{}];
             });

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -212,8 +211,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -368,8 +366,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
                 is_dst_valid,
                 dst_vector.template AsType()[Number<0>{}]);

-            constexpr auto move_on_dim = [&]() constexpr
-            {
+            constexpr auto move_on_dim = [&]() constexpr {
                 StaticallyIndexedArray move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -382,8 +379,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
                 });

                 return move_on_dim_;
-            }
-            ();
+            }();

             // move
             static_for<0, nDim, 1>{}([&](auto i) {
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
index 9b1ff3dbf8..65e63993a6 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
@@ -421,8 +421,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
         {
             constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);

-            auto forward_step_scatter = [&]() constexpr
-            {
+            auto forward_step_scatter = [&]() constexpr {
                 Index step_;

                 static_for<0, nDim, 1>{}([&](auto i) {
@@ -430,8 +429,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
                 });

                 return step_;
-            }
-            ();
+            }();

             static_for<0, nDst, 1>{}([&](auto i) {
                 move_tensor_coordinate(
                     dst_descs[i],
@@ -493,8 +491,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
         {
             constexpr auto reset_step =
                 DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{});

-            auto reset_step_scatter = [&]() constexpr
-            {
+            auto reset_step_scatter = [&]() constexpr {
                 Index step_;

                 static_for<0, nDim, 1>{}([&](auto i) {
                     step_(i) =
@@ -502,8 +499,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
                 });

                 return step_;
-            }
-            ();
+            }();

             return reset_step_scatter;
         }
     }
diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
index a191c75099..977c622f06 100644
--- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
@@ -30,7 +30,8 @@ template <
     typename ADataType       = float,
     typename CDataType       = float,
     index_t NumGroupsToMerge = 1,
-    typename IndexType       = index_t>
+    typename IndexType       = index_t,
+    bool CTranspose          = false>
 struct TransformConvBwdDataToGemm_v1
 {
     private:
@@ -555,6 +556,41 @@ struct TransformConvBwdDataToGemm_v1
                 return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_));
             }
         }
+        else if constexpr(is_same_v)
+        {
+            // assume packed
+            static_assert(ConvBwdDataSpecialization ==
+                          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
+                              Filter1x1Stride1Pad0);
+
+            const auto out_gemm_raw_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N_, Ho_ * Wo_, K_), make_tuple(NStrideTensorA_, I1, KStrideTensorA_));
+
+            return transform_tensor_descriptor(
+                out_gemm_raw_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)),
+                           make_pass_through_transform(K_)),
+                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else if constexpr(is_same_v)
+        {
+            // assume packed
+            static_assert(ConvBwdDataSpecialization ==
+                          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
+                              Filter1x1Stride1Pad0);
+
+            const auto out_gemm_raw_grid_desc =
+                make_naive_tensor_descriptor(make_tuple(N_, Do_ * Ho_ * Wo_, K_),
+                                             make_tuple(NStrideTensorA_, I1, KStrideTensorA_));
+
+            return transform_tensor_descriptor(
+                out_gemm_raw_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)),
+                           make_pass_through_transform(K_)),
+                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
unsupported layout: " + ALayout::name()); @@ -608,7 +644,9 @@ struct TransformConvBwdDataToGemm_v1 (is_same_v || is_same_v || is_same_v || - is_same_v), + is_same_v || + is_same_v || + is_same_v), bool>::type = false> __host__ __device__ auto MakeADescriptor_AK0_M_AK1() const { @@ -848,16 +886,16 @@ struct TransformConvBwdDataToGemm_v1 } } - template || - is_same_v), - bool>::type = false> + template < + typename BLayout_ = BLayout, + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const { - // assume packed - // k_y_x_c for 2d or k_z_y_x_c for 3d - const auto wei_grid_desc = MakeWeiGridDesc(); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: @@ -886,6 +924,12 @@ struct TransformConvBwdDataToGemm_v1 } else { + // assume packed + // k_y_x_c for 2d or k_z_y_x_c for 3d + static_assert(is_same_v || + is_same_v); + const auto wei_grid_desc = MakeWeiGridDesc(); + // GemmK is different for each GEMM const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_); const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); @@ -1059,6 +1103,7 @@ struct TransformConvBwdDataToGemm_v1 bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { + static_assert(CTranspose == false); // assume strided // n_hi_wi_c for 2d n_di_hi_wi_c for 3d const auto in_grid_desc = MakeInGridDesc(); @@ -1314,6 +1359,48 @@ struct TransformConvBwdDataToGemm_v1 } } + template || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + const auto in_grid_desc = make_naive_tensor_descriptor( + make_tuple(N_, C_, Di_ * Hi_ * Wi_), make_tuple(NStrideTensorC_, CStrideTensorC_, I1)); + + static_assert(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0); + + if constexpr(CTranspose) + { + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(C_), + make_merge_transform(make_tuple(N_, Di_ * Hi_ * Wi_))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmNPerBlock, GemmMPerBlock), + Sequence{}); + } + else + { + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, Di_ * Hi_ * Wi_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + } + } // for input bias template (reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); #endif // #ifndef CK_CODE_GEN_RTC @@ -1426,7 +1426,7 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); #endif // #ifndef CK_CODE_GEN_RTC @@ -1503,7 +1503,7 @@ __device__ static inline fp8x2_storage_t 
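// The prand_generator hunks above feed stochastic rounding (SR) for fp8
// conversion. A host-only sketch of the idea, with a toy hash standing in
// for CK's generator (seed handling, saturation and NaN cases are omitted):
// hash the value and its address into random bits, add them below the kept
// mantissa, then truncate.
#include <cstdint>
#include <cstring>
#include <cstdio>

static std::uint32_t toy_rng(const void* p, float v) // stand-in for prand_generator
{
    std::uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));
    const auto addr = static_cast<std::uint32_t>(reinterpret_cast<std::uintptr_t>(p));
    return (addr * 2654435761u) ^ (bits * 2246822519u);
}

static float round_sr(float x, int keep_bits) // keep_bits = 3 resembles e4m3
{
    std::uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    const std::uint32_t drop_mask = (1u << (23 - keep_bits)) - 1;
    u += toy_rng(&x, x) & drop_mask; // random carry decides the rounding direction
    u &= ~drop_mask;                 // then truncate the dropped mantissa bits
    float y;
    std::memcpy(&y, &u, sizeof(y));
    return y;
}

int main() { std::printf("0.3f -> %g under SR\n", round_sr(0.3f, 3)); }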
cvt_float_to_fp8(const float2_t f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f[0]); + rng = prand_generator(reinterpret_cast(&f), f[0]); #else rng = prand_generator(reinterpret_cast(&f), f[0]); #endif // #ifndef CK_CODE_GEN_RTC @@ -1704,7 +1704,7 @@ __host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&x), + rng = prand_generator(reinterpret_cast(&x), static_cast(x[0])); #else rng = prand_generator(reinterpret_cast(&x), @@ -1734,7 +1734,7 @@ using bf8_t = bf8_ocp_t; #define CK_FP8_TYPE_FNUZ 0 #define CK_FP8_TYPE_OCP 1 #else -using f8_t = f8_fnuz_t; +using f8_t = f8_fnuz_t; using bf8_t = bf8_fnuz_t; #define CK_FP8_TYPE_FNUZ 1 #define CK_FP8_TYPE_OCP 0 diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 8646b8393b..02a7a72b8c 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1396,8 +1396,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32> #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1427,8 +1427,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> { #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1459,8 +1459,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1490,8 +1490,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> { #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1522,8 +1522,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32> #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1553,8 +1553,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> { #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1585,8 +1585,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32> #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1616,8 +1616,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> { #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index bd0ca42ecd..d6524283db 100644 --- 
a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -19,7 +19,7 @@ __host__ __device__ constexpr auto container_push_back(const Array { Array r; - static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; }); r(Number{}) = x; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 8f5a45bdf0..5fbe30d21b 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -50,7 +50,7 @@ struct f4x2_pk_t __host__ __device__ inline type unpack(Number) const { static_assert(I < 2, "Index is out of range."); - if constexpr(I == 0) + if constexpr(I == 1) return (data >> 4); else return data & 0b00001111; @@ -58,7 +58,7 @@ struct f4x2_pk_t __host__ __device__ inline type pack(const type x0, const type x1) { - return (x0 << 4) | (x1 & 0b00001111); + return (x1 << 4) | (x0 & 0b00001111); } // Compare operator diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index ed42b22daf..027290dbf8 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -232,7 +232,7 @@ struct DynamicBuffer #if CK_USE_AMD_BUFFER_LOAD bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); #else - bool constexpr use_amd_buffer_addressing = false; + bool constexpr use_amd_buffer_addressing = false; #endif #if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 46ba32bb87..2f5b804d16 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace ck { namespace internal { diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp index a700fcfff1..8cb37b68b2 100644 --- a/include/ck/utility/is_detected.hpp +++ b/include/ck/utility/is_detected.hpp @@ -25,8 +25,8 @@ struct detector>, Op, Args...> struct nonesuch { - ~nonesuch() = delete; - nonesuch(nonesuch const&) = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; void operator=(nonesuch const&) = delete; }; diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index 7b079c541c..993b70a3fb 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -75,7 +75,7 @@ struct MagicDivision // integral_constant template __host__ __device__ static constexpr auto - CalculateMagicNumbers(integral_constant) + CalculateMagicNumbers(integral_constant) { constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor}); @@ -88,7 +88,7 @@ struct MagicDivision template __host__ __device__ static constexpr auto - CalculateMagicMultiplier(integral_constant) + CalculateMagicMultiplier(integral_constant) { constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor}); @@ -97,7 +97,7 @@ struct MagicDivision template __host__ __device__ static constexpr auto - CalculateMagicShift(integral_constant) + CalculateMagicShift(integral_constant) { constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor}); @@ -107,21 +107,21 @@ struct MagicDivision // integral_constant template __host__ __device__ static constexpr auto - CalculateMagicNumbers(integral_constant) + CalculateMagicNumbers(integral_constant) { return CalculateMagicNumbers(integral_constant{}); } template __host__ __device__ static constexpr auto - CalculateMagicMultiplier(integral_constant) + 
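// The data_type.hpp hunk above settles the f4x2 nibble convention: element 0
// lives in the low nibble, element 1 in the high nibble, so pack() and
// unpack() round-trip. A self-checking sketch, with uint8_t standing in for
// the packed f4x2_pk_t storage:
#include <cstdint>

constexpr std::uint8_t pack_f4x2(std::uint8_t x0, std::uint8_t x1)
{
    return static_cast<std::uint8_t>((x1 << 4) | (x0 & 0b00001111)); // x1 high, x0 low
}

template <int I>
constexpr std::uint8_t unpack_f4x2(std::uint8_t data)
{
    static_assert(I < 2, "Index is out of range.");
    if constexpr(I == 1)
        return data >> 4;         // element 1 <- high nibble
    else
        return data & 0b00001111; // element 0 <- low nibble
}

static_assert(unpack_f4x2<0>(pack_f4x2(0x5, 0xA)) == 0x5, "element 0 round-trips");
static_assert(unpack_f4x2<1>(pack_f4x2(0x5, 0xA)) == 0xA, "element 1 round-trips");
int main() { return 0; }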
CalculateMagicMultiplier(integral_constant) { return CalculateMagicMultiplier(integral_constant{}); } template __host__ __device__ static constexpr auto - CalculateMagicShift(integral_constant) + CalculateMagicShift(integral_constant) { return CalculateMagicShift(integral_constant{}); } diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index 90a018fe3a..7de84d974c 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -377,10 +377,7 @@ inline __host__ __device__ float2_t scaled_type_convert(e8m0_b f4x2_t f4x2_array[4]; } value{}; value.f4x2_array[0] = x; - float2_t tmp = - __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert(scale), 0); - // permute high bits and low bits to match the order of the original vector - return float2_t{tmp[1], tmp[0]}; + return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert(scale), 0); #else float2_t ret{utils::to_float( scale, x.template AsType()[Number<0>{}].unpack<>(Number<0>{})), @@ -406,10 +403,9 @@ inline __host__ __device__ float32_t scaled_type_convert(e8m float f_scale = type_convert(scale); ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0); - // permute high bits and low bits to match the order of the original vector - ret[2 * idx] = op[1]; - ret[2 * idx + 1] = op[0]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0); + ret[2 * idx] = op[0]; + ret[2 * idx + 1] = op[1]; }); return ret; diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 497625f7e2..75f0c92c58 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -342,8 +342,8 @@ struct sequence_reverse using seq_split = sequence_split; using type = typename sequence_merge< - typename sequence_reverse::type, - typename sequence_reverse::type>::type; + typename sequence_reverse::type, + typename sequence_reverse::type>::type; }; template diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index d6b6eac26c..7652e73809 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -33,7 +33,7 @@ __device__ void block_sync_lds_direct_load() { #ifdef __gfx12__ asm volatile("\ - s_wait_vmcnt 0x0 \n \ + s_wait_loadcnt 0x0 \n \ s_wait_dscnt 0x0 \n \ s_barrier_signal -1 \n \ s_barrier_wait -1 \ diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 05e461fa63..99538ac78c 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -39,6 +39,19 @@ namespace details { } // namespace details } // namespace +#if defined(__gfx950__) +inline __device__ bhalf_t static_cast_float_to_bf16(float x) +{ + union + { + uint16_t uint16; + __bf16 bf16; + } out; + out.bf16 = static_cast<__bf16>(x); + return out.uint16; +} +#endif + // Declare a template function for bf16 conversion using RTN template __host__ __device__ constexpr Y bf16_convert_rtn(X x); @@ -47,6 +60,9 @@ __host__ __device__ constexpr Y bf16_convert_rtn(X x); template <> inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) { +#if defined(__gfx950__) + return static_cast_float_to_bf16(x); +#else // Nan check if(x != x) { @@ -63,6 +79,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(fl constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1); return uint16_t((u.int32 + 
first_bf16_mantisa_bit + rounding_bias) >> 16); +#endif } // convert fp16 to bfp16 via fp32 with RTN if higher precision is needed @@ -242,7 +259,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif // #ifndef CK_CODE_GEN_RTC @@ -310,7 +327,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif // #ifndef CK_CODE_GEN_RTC @@ -1401,8 +1418,7 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - // permute high bits and low bits to match the order of the original vector - value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0); + value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0); return value.f4x2_array[0]; #else union @@ -1410,8 +1426,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - uint8_t l = utils::sat_convert_to_type(x[1] / scale); - uint8_t h = utils::sat_convert_to_type(x[0] / scale); + uint8_t l = utils::sat_convert_to_type(x[0] / scale); + uint8_t h = utils::sat_convert_to_type(x[1] / scale); value.bitwise = (h << 4) | l; return value.f4x2_array[0]; #endif @@ -1429,9 +1445,8 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0 } f4_values{}, tmp_values{}; ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { - // permute high bits and low bits to match the order of the original vector tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32( - tmp_values.bitwise, x[2 * idx + 1], x[2 * idx], scale, 0); + tmp_values.bitwise, x[2 * idx], x[2 * idx + 1], scale, 0); f4_values.f4x2_array[idx] = tmp_values.f4x2_array[0]; }); @@ -1480,7 +1495,7 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif @@ -1500,14 +1515,12 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - // permute high bits and low bits to match the order of the original vector - value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0); + value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0); return value.f4x2_array[0]; #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #endif @@ -1516,8 +1529,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - uint8_t l = utils::sat_convert_to_type_sr(x[1] / scale, rng); - 
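// A host-only worked version of the round-to-nearest-even path kept above
// for non-gfx950 targets: add the bf16 LSB (the quantity the code calls
// first_bf16_mantisa_bit) plus the 0x7FFF bias, then drop the low 16 bits.
// The quiet-NaN pattern below is an illustrative stand-in for the library's
// NaN handling.
#include <cstdint>
#include <cstring>
#include <cstdio>

static std::uint16_t bf16_rtn(float x)
{
    if(x != x)
        return 0x7FC0; // assumed qNaN bf16 pattern
    std::uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    const std::uint32_t bf16_lsb      = (u >> 16) & 1;
    const std::uint32_t rounding_bias = (1u << 15) - 1;
    return static_cast<std::uint16_t>((u + bf16_lsb + rounding_bias) >> 16);
}

int main()
{
    // 1 + 2^-8 sits exactly between two bf16 values; RNE picks the even one
    std::printf("bf16(1.00390625f) = 0x%04x\n", bf16_rtn(1.00390625f)); // 0x3f80
}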
uint8_t h = utils::sat_convert_to_type_sr(x[0] / scale, rng); + uint8_t l = utils::sat_convert_to_type_sr(x[0] / scale, rng); + uint8_t h = utils::sat_convert_to_type_sr(x[1] / scale, rng); value.bitwise = (h << 4) | l; return value.f4x2_array[0]; #endif @@ -1544,20 +1557,15 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f float_values.floatx32_array = x; ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { - // permute high bits and low bits to match the order of the original vector f4_values.f4x2_array[idx] = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - f4_values.bitwise, - float2_t{float_values.floatx2_array[idx][1], float_values.floatx2_array[idx][0]}, - rng, - scale, - 0); + f4_values.bitwise, float_values.floatx2_array[idx], rng, scale, 0); }); return f4_values.f4x32_array; #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #endif @@ -1648,9 +1656,7 @@ inline __host__ __device__ float2_t type_convert(f4x2_t x) } value{}; value.f4x2_array[0] = x; float scale = 1.0f; - float2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0); - // permute high bits and low bits to match the order of the original vector - return float2_t{tmp[1], tmp[0]}; + return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0); #else float2_t ret{ utils::to_float(NumericLimits::Binary_1(), @@ -1676,10 +1682,9 @@ inline __host__ __device__ float32_t type_convert(f4x32_t x) float scale = 1.0f; ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0); - // permute high bits and low bits to match the order of the original vector - ret[2 * idx] = op[1]; - ret[2 * idx + 1] = op[0]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0); + ret[2 * idx] = op[0]; + ret[2 * idx + 1] = op[1]; }); return ret; @@ -1812,7 +1817,7 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif @@ -2150,7 +2155,7 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 8dabb58451..26cfcaa2f0 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -407,17 +407,17 @@ struct Tensor ElementSpaceSize, true /*InvalidElementUseNumericalZeroValue*/>; using StaticBufferType = std::conditional_t< - is_scalar_type::value, - StaticBuffer, - StaticBufferTupleOfVector>::vector_size, - scalar_type>::vector_size, - true /*InvalidElementUseNumericalZeroValue*/>>; + is_scalar_type::value, + StaticBuffer, + StaticBufferTupleOfVector>::vector_size, + scalar_type>::vector_size, + true /*InvalidElementUseNumericalZeroValue*/>>; // If register use static buffer, else use dynamic buffer using Buffer = std::conditional_t; diff --git a/include/ck_tile/core.hpp 
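// The fp4 hunks above delete the "permute high bits and low bits" shuffles:
// after the fix, the packed low nibble decodes to element 0 and the high
// nibble to element 1, for the builtin and the software fallback alike. A
// toy e2m1 decoder showing that order (the magnitude table is the standard
// e2m1 value set; the scale is illustrative):
#include <cstdint>
#include <cstdio>

static float fp4_e2m1_to_float(std::uint8_t nib)
{
    static const float mag[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    const float m = mag[nib & 0x7];
    return (nib & 0x8) ? -m : m; // top bit of the nibble is the sign
}

int main()
{
    const std::uint8_t packed = 0x2A; // high nibble 0x2 (+1.0), low nibble 0xA (-1.0)
    const float scale = 2.0f;
    const float e0 = scale * fp4_e2m1_to_float(packed & 0x0F); // element 0 <- low
    const float e1 = scale * fp4_e2m1_to_float(packed >> 4);   // element 1 <- high
    std::printf("e0=%g e1=%g\n", e0, e1); // e0=-2 e1=2
    return 0;
}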
b/include/ck_tile/core.hpp index 10dfdd7d28..188cebaabc 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -66,6 +66,7 @@ #include "ck_tile/core/tensor/transpose_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/debug.hpp" #include "ck_tile/core/utility/env.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp" diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index aaa7db2574..f7f9489f4c 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1259,7 +1259,7 @@ struct slice : public base_transform<1, 1> printf("}"); } // namespace ck -}; // namespace ck +}; // namespace ck /* * \brief lower_idx = upper_idx % modulus. diff --git a/include/ck_tile/core/algorithm/space_filling_curve.hpp b/include/ck_tile/core/algorithm/space_filling_curve.hpp index 6591acddb9..648a1251be 100644 --- a/include/ck_tile/core/algorithm/space_filling_curve.hpp +++ b/include/ck_tile/core/algorithm/space_filling_curve.hpp @@ -100,10 +100,8 @@ struct space_filling_curve // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the // idim-th element of multidimensional index. // All constexpr variables have to be captured by VALUE. - constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr - { - constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr - { + constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr { + constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr { auto res = idx_1d.value; auto id = 0; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index aafc6c0a85..0932f39ca7 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -302,12 +302,12 @@ struct buffer_load_if<16, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); if constexpr(pre_nop) asm volatile("s_nop 4\n" @@ -336,12 +336,12 @@ struct buffer_load_if<8, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -369,12 +369,12 @@ struct buffer_load_if<4, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" 
"v_cmpx_le_u32 exec, 1, %4\n" @@ -402,12 +402,12 @@ struct buffer_load_if<2, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -435,12 +435,12 @@ struct buffer_load_if<1, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -624,7 +624,7 @@ struct buffer_store_if<16> { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; + using mbuf_t = fp32x4_t; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -681,7 +681,7 @@ struct buffer_store_if<4> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -709,7 +709,7 @@ struct buffer_store_if<2> { static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = short; + using mbuf_t = short; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -737,7 +737,7 @@ struct buffer_store_if<1> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -1783,60 +1783,34 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; + + // Used to catch the cases when src_immediate_addr_offset is NOT 0. + // Remove this assert once other sizes are implemented. + assert(src_immediate_addr_offset == 0 && + "wrong! not implemented src_immediate_addr_offset size, only 0 supported"); + ignore = src_immediate_addr_offset; + #if defined(__gfx950__) static_assert(bytes == 4 || bytes == 12 || bytes == 16, "wrong! only support in dword, dwordx3, dwordx4"); - ignore = src_wave_addr_offset; - ignore = src_immediate_addr_offset; - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - 0, - 0, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - 0, - 0, - static_cast(coherence)); - } + src_wave_addr_offset = 0; #else static_assert(bytes == 4, "wrong! not implemented vector size"); - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? 
src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } #endif + + // Set up v_offset: + index_t v_offset = src_thread_addr_offset; + if constexpr(oob_conditional_check) + v_offset = flag ? v_offset : src_wave_buffer_resource[2]; + + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); } template (in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); } - else if constexpr(std::is_same_v, ck_tile::fp8_t>) + else if constexpr(std::is_same_v, ck_tile::fp8_t> || + std::is_same_v, ck_tile::bf8_t> || + std::is_same_v, ck_tile::int8_t>) { - typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t; - __attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>( + typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t; + __attribute__((address_space(3))) llvm_i32x2_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>( reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index ca4ff8ca7e..ce4af430e2 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1553,60 +1553,34 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; + + // Used to catch the cases when src_immediate_addr_offset is NOT 0. + // Remove this assert once other sizes are implemented. + assert(src_immediate_addr_offset == 0 && + "wrong! not implemented src_immediate_addr_offset size, only 0 supported"); + ignore = src_immediate_addr_offset; + #if defined(__gfx950__) static_assert(bytes == 4 || bytes == 12 || bytes == 16, "wrong! only support in dword, dwordx3, dwordx4"); - ignore = src_wave_addr_offset; - ignore = src_immediate_addr_offset; - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - 0, - 0, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - 0, - 0, - static_cast(coherence)); - } + src_wave_addr_offset = 0; #else static_assert(bytes == 4, "wrong! not implemented vector size"); - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? 
src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } #endif + + // Set up v_offset: + index_t v_offset = src_thread_addr_offset; + if constexpr(oob_conditional_check) + v_offset = flag ? v_offset : src_wave_buffer_resource[2]; + + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); } template (in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); } - else if constexpr(std::is_same_v, ck_tile::fp8_t>) + else if constexpr(std::is_same_v, ck_tile::fp8_t> || + std::is_same_v, ck_tile::bf8_t> || + std::is_same_v, ck_tile::int8_t>) { - typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t; - __attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>( + typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t; + __attribute__((address_space(3))) llvm_i32x2_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>( reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } diff --git a/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp b/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp index 7ffe6dc0fb..665be1b167 100644 --- a/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp +++ b/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp @@ -10,53 +10,55 @@ namespace ck_tile { // this generate wave level tile distribution -template +template struct LaneGroupTransposeTraits; -template -struct LaneGroupTransposeTraits> +template +struct LaneGroupTransposeTraits> { + static_assert(LaneGroupSize == 16 || LaneGroupSize == 32 || LaneGroupSize == 64, + "LaneGroupSize must be 16, 32, or 64"); // before transpose, 4x16 static constexpr index_t ksecondDim = 4; - static constexpr index_t kleadDim = 16; + static constexpr index_t kleadDim = LaneGroupSize; // after transpose, 16x4 - static constexpr index_t ksecondDimT = 16; + static constexpr index_t ksecondDimT = LaneGroupSize; static constexpr index_t kleadDimT = 4; template - using TileDistribution = - tile_distribution_encoding, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 1, 2>, - sequence<1, 1, 3>>; + using TileDistribution = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 1, 2>, + sequence<1, 1, 4>>; }; -template -struct LaneGroupTransposeTraits> +template +struct LaneGroupTransposeTraits> { static constexpr index_t ksecondDim = 8; - static constexpr index_t kleadDim = 16; + static constexpr index_t kleadDim = LaneGroupSize; - static constexpr index_t ksecondDimT = 16; + static constexpr index_t ksecondDimT = LaneGroupSize; static constexpr index_t kleadDimT = 8; template - using TileDistribution = - tile_distribution_encoding, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 1, 2>, - sequence<1, 1, 3>>; + using TileDistribution = tile_distribution_encoding< + 
sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 1, 2>, + sequence<1, 1, 4>>; }; /* @@ -72,15 +74,15 @@ struct LaneGroupTransposeTraits> * consecutive. */ template CK_TILE_DEVICE constexpr auto make_transposed_distr_encode() { - using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits:: - template TileDistribution; - return xdllevel_dstr_encoding{}; + return typename LaneGroupTransposeTraits:: + template TileDistribution{}; } } // namespace ck_tile diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 3dd9604b01..0723026836 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -10,6 +10,15 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111 +#define CK_TILE_VMCNT(cnt) \ + ([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \ + ((cnt) & 0b1111) | (((cnt) & 0b110000) << 10)) +#define CK_TILE_EXPCNT(cnt) \ + ([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4)) +#define CK_TILE_LGKMCNT(cnt) \ + ([]() { static_assert(!((cnt) >> 4), "LGKM only has 4 bits"); }(), ((cnt) << 8)) + namespace ck_tile { template @@ -113,13 +122,72 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) #endif } +// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html +struct waitcnt_arg +{ + // bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210 + // [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV + CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111; + + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111; + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111; + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111; + + template + CK_TILE_DEVICE static constexpr index_t from_vmcnt() + { + static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); + return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10)); + } + + template + CK_TILE_DEVICE static constexpr index_t from_expcnt() + { + static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]"); + return MAX & (cnt << 4); + } + + template + CK_TILE_DEVICE static constexpr index_t from_lgkmcnt() + { + static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]"); + return MAX & (cnt << 8); + } +}; + +template +CK_TILE_DEVICE void s_waitcnt() +{ + __builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt() | + waitcnt_arg::from_expcnt() | + waitcnt_arg::from_lgkmcnt()); +} + +template +CK_TILE_DEVICE void s_waitcnt_barrier() +{ + s_waitcnt(); + __builtin_amdgcn_s_barrier(); +} + CK_TILE_DEVICE void block_sync_lds_direct_load() { +#if 1 + // invoke clang builtins which *should* produce the same result as the inline asm below + // difference: inline asm is being compiled to wait vmcnt(0) after the barrier + s_waitcnt_barrier<0, waitcnt_arg::kMaxExpCnt, 0>(); +#else + // same content as in old CK (#999) asm volatile("\ s_waitcnt vmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif } CK_TILE_DEVICE void s_nop(index_t cnt = 0) diff --git a/include/ck_tile/core/container/container_helper.hpp b/include/ck_tile/core/container/container_helper.hpp index 474eda80d1..1a631bd95e 100644 --- a/include/ck_tile/core/container/container_helper.hpp +++ b/include/ck_tile/core/container/container_helper.hpp @@ -16,7 +16,7 @@ template CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array& a, const TData& x) { array r; - static_for<0, NSize, 
1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; }); r[number{}] = x; return r; } diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index b187b71830..94309dd5dd 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -1236,9 +1236,8 @@ constexpr auto reverse_slice_sequence(Seq, template ::type> -constexpr auto slice_sequence(Seq, - number, - Mask = typename uniform_sequence_gen::type{}) +constexpr auto +slice_sequence(Seq, number, Mask = typename uniform_sequence_gen::type{}) { constexpr auto r = reverse_slice_sequence(Seq{}.reverse(), number{}, Mask{}.reverse()); diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 3700d348e7..63d145d8b9 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -262,12 +262,18 @@ struct tuple : impl::tuple_base, T...> return flag; } + CK_TILE_HOST_DEVICE static constexpr bool IsTuple() { return true; } + #define TP_COM_() static_assert(I < size(), "wrong! out of range") // clang-format off - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv(*this); } - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) const { TP_COM_(); return get(); } - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv(*this); } - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) const & { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() & { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) & { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() && { TP_COM_(); return impl::getv(std::move(*this)); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) && { TP_COM_(); return std::move(*this).template get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const && { TP_COM_(); return impl::getv(std::move(*this)); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) const &&{ TP_COM_(); return std::move(*this).template get(); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv(*this); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number) const { TP_COM_(); return get(); } @@ -470,6 +476,12 @@ transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence) return make_tuple(f(x.at(number{}), y.at(number{}), z.at(number{}))...); } +template +constexpr decltype(auto) apply_impl(F&& f, Tuple&& t, sequence) +{ + return std::forward(f)(std::forward(t).get(number{})...); +} + } // namespace detail template @@ -493,6 +505,13 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); } +template +constexpr decltype(auto) apply(F&& f, Tuple&& t) +{ + constexpr index_t N = std::decay_t::size(); + return detail::apply_impl(std::forward(f), std::forward(t), make_index_sequence{}); +} + namespace detail { template diff --git a/include/ck_tile/core/numeric/float8.hpp 
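// The waitcnt_arg helper above packs the gfx9-style s_waitcnt immediate:
// vmcnt spans bits [3:0] plus [15:14], expcnt bits [6:4], lgkmcnt bits
// [11:8]. A compile-time re-check of that layout; the free functions mirror
// waitcnt_arg::from_* for illustration only.
#include <cstdint>

constexpr std::uint32_t from_vmcnt(std::uint32_t cnt) // valid range [0..63]
{
    return (cnt & 0b1111) | ((cnt & 0b110000) << 10); // low 4 bits, then bits 14..15
}
constexpr std::uint32_t from_expcnt(std::uint32_t cnt) { return cnt << 4; }  // [0..7]
constexpr std::uint32_t from_lgkmcnt(std::uint32_t cnt) { return cnt << 8; } // [0..15]

// vmcnt(0) lgkmcnt(0) with expcnt left at its max of 7 -- the combination
// block_sync_lds_direct_load requests -- leaves only the exp field set:
static_assert((from_vmcnt(0) | from_expcnt(7) | from_lgkmcnt(0)) == 0b0000'0000'0111'0000, "");

// a full vmcnt of 63 really does split across the two vmcnt bit fields:
static_assert(from_vmcnt(63) == 0b1100'0000'0000'1111, "");

int main() { return 0; }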
b/include/ck_tile/core/numeric/float8.hpp index b5da468319..a3ce614f84 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -75,7 +75,7 @@ struct alignas(1) float8_e4m3_t #if CK_TILE_USE_OCP_FP8 static constexpr int bias = 7; // OCP #else - static constexpr int bias = 8; // FNUZ + static constexpr int bias = 8; // FNUZ #endif using raw_type = uint8_t; raw_type data; diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 8176fe551c..b8a31ba8fc 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -31,8 +31,8 @@ struct scales CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {} template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const - -> decltype(std::declval() * rhs) + CK_TILE_HOST_DEVICE constexpr auto + operator()(const Right& rhs) const -> decltype(std::declval() * rhs) { return lhs_ * rhs; } @@ -43,13 +43,13 @@ struct scales /// FIXME: create macro to replace '__host__ __device__' and nothing more template -__host__ __device__ scales(Scale)->scales; +__host__ __device__ scales(Scale) -> scales; template struct plus { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs + rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs + rhs) { return lhs + rhs; } @@ -59,21 +59,21 @@ template <> struct plus { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs + rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs + rhs) { return lhs + rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ plus()->plus; +__host__ __device__ plus() -> plus; template struct minus { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs - rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs - rhs) { return lhs - rhs; } @@ -83,21 +83,21 @@ template <> struct minus { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs - rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs - rhs) { return lhs - rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ minus()->minus; +__host__ __device__ minus() -> minus; template struct multiplies { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs * rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs * rhs) { return lhs * rhs; } @@ -107,15 +107,15 @@ template <> struct multiplies { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs * rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs * rhs) { return lhs * rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ multiplies()->multiplies; +__host__ __device__ multiplies() -> multiplies; template struct maximize @@ -327,8 +327,8 @@ CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... 
ys) template struct equal { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs == rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs == rhs) { return lhs == rhs; } @@ -338,15 +338,15 @@ template <> struct equal { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs == rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs == rhs) { return lhs == rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ equal()->equal; +__host__ __device__ equal() -> equal; template <> struct equal @@ -369,8 +369,8 @@ struct equal template struct less { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs < rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs < rhs) { return lhs < rhs; } @@ -380,21 +380,21 @@ template <> struct less { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs < rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs < rhs) { return lhs < rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ less()->less; +__host__ __device__ less() -> less; template struct less_equal { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs <= rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs <= rhs) { return lhs <= rhs; } @@ -404,15 +404,15 @@ template <> struct less_equal { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs <= rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs <= rhs) { return lhs <= rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ less_equal()->less_equal; +__host__ __device__ less_equal() -> less_equal; template <> struct less_equal
diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 541093e337..ba8b87a9b8 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -116,6 +116,24 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x) +{ + uint8_t x_u8 = ck_tile::bit_cast(x); + + float x_l = ((x_u8 & 0x0f) >> 0); + float x_h = ((x_u8 & 0xf0) >> 4); + + x_l = x_l > 7 ? x_l - 16 : x_l; + x_h = x_h > 7 ? x_h - 16 : x_h; + +#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE + fp32x2_t res = {x_h, x_l}; +#else + fp32x2_t res = {x_l, x_h}; +#endif + return res; +} + CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x);
diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 5cae332007..13b038bc48 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -994,51 +994,34 @@ struct buffer_view" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix + // clang-format off static_assert( - (std::is_same_v, int8_t> && - std::is_same_v, int8_t>) || - (std::is_same_v, int8_t> && - std::is_same_v, int8x2_t>) || - (std::is_same_v, int8_t> && - std::is_same_v, int8x4_t>) || - (std::is_same_v, int8_t> && - std::is_same_v, int8x8_t>) || - (std::is_same_v, int8_t> && - std::is_same_v, int8x16_t>) || - (std::is_same_v, int8x4_t> && - std::is_same_v, int8x4_t>) || - (std::is_same_v, int8x8_t> && - std::is_same_v, int8x8_t>) || - (std::is_same_v, int8x16_t> && - std::is_same_v, int8x16_t>) || + (std::is_same_v, int8_t> && std::is_same_v, int8_t>) || + (std::is_same_v, int8_t> && std::is_same_v, int8x2_t>) || + (std::is_same_v, int8_t> && std::is_same_v, int8x4_t>) || + (std::is_same_v, int8_t> && std::is_same_v, int8x8_t>) || + (std::is_same_v, int8_t> && std::is_same_v, int8x16_t>) || + (std::is_same_v, int8x4_t> && std::is_same_v, int8x4_t>) || + (std::is_same_v, int8x8_t> && std::is_same_v, int8x8_t>) || + (std::is_same_v, int8x16_t> && std::is_same_v, int8x16_t>) || // int8 on thread buffer - (std::is_same_v, int8_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, int8_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, int8_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, int8_t> && - std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + // ext_vector_type for pk_int4 must use int8_t as type - (std::is_same_v, pk_int4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4x4_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4x8_t> && - std::is_same_v, thread_buffer>) || - (std::is_same_v, pk_int4x16_t> && - std::is_same_v, thread_buffer>), + (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x4_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x16_t> && std::is_same_v, thread_buffer>), "wrong!
not implemented for this combination, please add " "implementation"); + // clang-format on if constexpr((std::is_same_v, int8_t> && std::is_same_v, int8_t>) || @@ -1090,6 +1073,8 @@ struct buffer_view, int8_t> && std::is_same_v, int8x16_t>) || + (std::is_same_v, int8_t> && + std::is_same_v, thread_buffer>) || (std::is_same_v, pk_int4_t> && std::is_same_v, thread_buffer>)) { diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index d178ccb72c..1535250722 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -17,6 +17,11 @@ namespace ck_tile { +constexpr int DS_READ_TR_SIZE() +{ + return 8; // Literal constant, evaluated at compile time +} + namespace util { template struct is_sequence_suffix @@ -45,48 +50,60 @@ constexpr bool is_sequence_suffix_v = is_sequence_suffix::valu template struct DefaultTranspose { + template struct Quad16 { - using InputEncoding = tile_distribution_encoding, - tuple, sequence<4, 4>>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16, + "LaneGroupSize must be 64, 32, or 16"); + using InputEncoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<2>>; - using OutputEncoding = tile_distribution_encoding, - tuple, sequence<4>>, - tuple>, - tuple>, - sequence<2>, - sequence<0>>; + using OutputEncoding = + tile_distribution_encoding, + tuple, sequence<4>>, + tuple>, + tuple>, + sequence<2>, + sequence<0>>; }; + template struct Quad8 { - using InputEncoding = tile_distribution_encoding, - tuple, sequence<2, 8>>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16, + "LaneGroupSize must be 64, 32, or 16"); + using InputEncoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<2>>; - using OutputEncoding = tile_distribution_encoding, - tuple, sequence<8>>, - tuple>, - tuple>, - sequence<2>, - sequence<0>>; + using OutputEncoding = + tile_distribution_encoding, + tuple, sequence<8>>, + tuple>, + tuple>, + sequence<2>, + sequence<0>>; }; // Select based on data size + template using QuadInputEncoding = std::conditional_t; + typename Quad16::InputEncoding, + typename Quad8::InputEncoding>; + template using QuadOutputEncoding = std::conditional_t; + typename Quad16::OutputEncoding, + typename Quad8::OutputEncoding>; // Always swap last two dimensions static constexpr auto transpose_dims = sequence<1, 0>{}; @@ -96,51 +113,79 @@ struct DefaultTranspose return idx; // Identity mapping }; - template - struct ValidationTraits + template + struct ValidationTraitsImpl { - static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_; - static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_; + using QuadEncoding = std::conditional_t, + QuadInputEncoding>; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto input_hs = InDstrEncode::hs_lengthss_; + static constexpr auto quad_hs = QuadEncoding::hs_lengthss_; // 1. Must be 2D tensor static constexpr bool dims_valid = (InDstrEncode::NDimX == 2); // 2. 
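// ValidationTraitsImpl above leans on util::is_sequence_suffix: the quad
// (lane-group) lengths must match the tail of the input distribution's H
// lengths. A minimal constexpr restatement of that test, with plain integer
// packs standing in for ck_tile::sequence:
#include <cstddef>

template <int... Is>
struct seq
{
    static constexpr int values[sizeof...(Is) == 0 ? 1 : sizeof...(Is)] = {Is...};
    static constexpr std::size_t size = sizeof...(Is);
};

template <class Suffix, class Full>
constexpr bool is_suffix()
{
    if constexpr(Suffix::size > Full::size)
        return false;
    else
    {
        for(std::size_t i = 0; i < Suffix::size; ++i)
            if(Suffix::values[i] != Full::values[Full::size - Suffix::size + i])
                return false;
        return true;
    }
}

// e.g. a 4x16 quad pattern validates against (..., 4, 16) but not (16, 4):
static_assert(is_suffix<seq<4, 16>, seq<2, 4, 16>>(), "");
static_assert(!is_suffix<seq<4, 16>, seq<16, 4>>(), "");
int main() { return 0; }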
Quad pattern must be suffix of input pattern static constexpr bool suffix_valid_dim0 = - util::is_sequence_suffix_v()), - decltype(input_hs_lengthss.template get<0>())>; + util::is_sequence_suffix_v; static constexpr bool suffix_valid_dim1 = - util::is_sequence_suffix_v()), - decltype(input_hs_lengthss.template get<1>())>; + util::is_sequence_suffix_v; // 3. PS→RHS mapping constraints - static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_; - static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_; + static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_; + static constexpr auto input_ps_minor = InDstrEncode::ps_to_rhss_minor_; - static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1; - static constexpr index_t ndimp_inner = - input_ps_to_rhss_major[number{}].size() - 1; + static constexpr auto quad_ps_major0 = QuadEncoding::ps_to_rhss_major_[I0]; + static constexpr auto quad_ps_minor0 = QuadEncoding::ps_to_rhss_minor_[I0]; + + static constexpr auto input_ps_major_last = + input_ps_major[number{}]; + static constexpr auto input_ps_minor_last = + input_ps_minor[number{}]; + + using psys_offset = ck_tile::sequence; + static constexpr auto shifted_quad_ps_minor0 = generate_sequence_v2( + [](auto i) { + return number{}; + }, + number{}); static constexpr bool ps_mapping_valid = - (input_ps_to_rhss_major[number{}][number{}] == 2) && - (input_ps_to_rhss_minor[number{}][number{}] == - input_hs_lengthss[number<1>{}].size() - 2) && - (input_ps_to_rhss_major[number{}][number{}] == 1) && - (input_ps_to_rhss_minor[number{}][number{}] == - input_hs_lengthss[number<0>{}].size() - 1); + util::is_sequence_suffix_v && + util::is_sequence_suffix_v; // 4. YS→RHS mapping constraints - static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_; - static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_; + static constexpr auto input_ys_major = InDstrEncode::ys_to_rhs_major_; + static constexpr auto input_ys_minor = InDstrEncode::ys_to_rhs_minor_; + static constexpr auto quad_ys_major = QuadEncoding::ys_to_rhs_major_; + static constexpr auto quad_ys_minor = QuadEncoding::ys_to_rhs_minor_; + static_assert(quad_ys_major.size() == 1 && quad_ys_minor.size() == 1, + "YS->RHS mapping must be single dimension"); + static_assert(quad_ys_major.back() == 2 && quad_ys_minor.back() == quad_hs[I1].size() - 1, + "YS->RHS mapping must be the last dimension"); static constexpr bool ys_mapping_valid = - (input_ys_to_rhs_major.back() == 2) && - (input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) && - (input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) && - (input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] == - input_hs_lengthss[number<0>{}].size() - 2); + (input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1); static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 && ps_mapping_valid && ys_mapping_valid; }; + + template + struct ValidationTraits + { + static constexpr bool value = + ValidationTraitsImpl::value || + ValidationTraitsImpl::value || + ValidationTraitsImpl::value; + static constexpr index_t LaneGroupSize = + ValidationTraitsImpl::value ? 64 + : ValidationTraitsImpl::value ? 32 + : ValidationTraitsImpl::value ? 
16 + : 0; + }; }; template struct TransposeTileDistrChecker @@ -154,111 +199,150 @@ struct TransposeTileDistrChecker // this is used to generate the transposed output tile distribution encoding // based on the input tile distribution encoding -template > -struct OutputTileDistributionTraits + typename Policy = DefaultTranspose, + bool ReverseDirection = false> +struct TransposeTileDistributionTraits { - using InDstrEncode = typename remove_cvref_t::DstrEncode; - static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_; - static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_; - static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_; + using InDstrEncode = remove_cvref_t; + static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_; + static constexpr index_t LaneGroupSize = + Policy::template ValidationTraits::LaneGroupSize; + static_assert(Policy::template ValidationTraits::value, + "The input tile distribution encoding is not valid for transpose!"); + + using QuadInputEncoding = std::conditional_t< // + ReverseDirection, + typename Policy::template QuadOutputEncoding, + typename Policy::template QuadInputEncoding>; + using QuadOutputEncoding = std::conditional_t< // + ReverseDirection, + typename Policy::template QuadInputEncoding, + typename Policy::template QuadOutputEncoding>; + + static constexpr auto quad_input_hs_lengthss = QuadInputEncoding::hs_lengthss_; + static constexpr auto quad_output_hs_lengthss = QuadOutputEncoding::hs_lengthss_; static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_; static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_; static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_; static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_; - static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_; - static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_; + static constexpr auto I0 = number<0>{}; + static constexpr auto quad_input_ps_to_rhss_major0 = QuadInputEncoding::ps_to_rhss_major_[I0]; + static constexpr auto quad_input_ps_to_rhss_minor0 = QuadInputEncoding::ps_to_rhss_minor_[I0]; + static constexpr auto quad_output_ps_to_rhss_major0 = QuadOutputEncoding::ps_to_rhss_major_[I0]; + static constexpr auto quad_output_ps_to_rhss_minor0 = QuadOutputEncoding::ps_to_rhss_minor_[I0]; + static constexpr auto quad_output_ys_to_rhs_major = QuadOutputEncoding::ys_to_rhs_major_; + static constexpr auto quad_output_ys_to_rhs_minor = QuadOutputEncoding::ys_to_rhs_minor_; + + static constexpr index_t dim0 = Policy::transpose_dims[0]; + static constexpr index_t dim1 = Policy::transpose_dims[1]; + + static constexpr auto swap_one_and_two = [](const index_t idx) { + return (idx == 1) ? 2 : (idx == 2) ? 
1 : idx; + }; // for transpose load - // append the reversed quad output hs lengths to the input hs lengthss after removing - // the quad_input_hs_lengthss - // then reverse the whole sequence to get the dst_out_hs_lengthss - static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss); - - static constexpr auto full_out_hs_lengthss = generate_tuple( + // remove the quad_input_hs_lengthss from the input_hs_lengthss for each dimension and reverse + // dims and append the quad_output_hs_lengthss to the end of each dimension + static constexpr auto outer_hs_lengthss = generate_tuple( [](auto i) { - return input_hs_lengthss[i] - .extract(typename arithmetic_sequence_gen<0, - input_hs_lengthss[i].size() - - quad_input_hs_lengthss[i].size(), - 1>::type{}) - .push_back(reversed_quad_output_hs_lengthss[i]); + constexpr auto input_i = input_hs_lengthss[i]; + constexpr auto outer_len = input_i.size() - quad_input_hs_lengthss[i].size(); + return typename sequence_split::left_type{}; + }, + number{}); + static constexpr auto reversed_outer_hs_lengthss = tuple_reverse(outer_hs_lengthss); + static constexpr auto dst_out_hs_lengthss = generate_tuple( + [](auto i) { + auto outer_i = reversed_outer_hs_lengthss[i]; + // append the reversed quad output hs lengths to the outer hs lengths + return outer_i.push_back(quad_output_hs_lengthss[i]); }, number{}); - static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss); - - // for PS→RHS mapping(both major and minor), we need to modify the last element of the major - // sequence - static constexpr auto modified_ps_to_rhss_major = generate_tuple( + // for PS→RHS mapping(both major and minor), we need to modify the last element (which is for + // thread distr) of the major sequence + static constexpr auto dst_ps_to_rhss_major = generate_tuple( + // for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed [](auto i) { if constexpr(i == input_ps_to_rhss_major.size() - 1) { constexpr auto current_size = input_ps_to_rhss_major[i].size(); - constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size(); + constexpr auto reduce_size = quad_input_ps_to_rhss_major0.size(); + constexpr auto quad_out = quad_output_ps_to_rhss_major0; constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract( typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{}); - return reduced_ps_to_rhss_major.push_back(number<2>{}); + return reduced_ps_to_rhss_major.transform(swap_one_and_two).push_back(quad_out); } else { - // For all other sequences, keep them unchanged - return input_ps_to_rhss_major[i]; + // For all other sequences (i.e. 
warp), keep them unchanged + return input_ps_to_rhss_major[i].transform(swap_one_and_two); } }, number{}); - static constexpr auto minor_last_index = - full_out_hs_lengthss[number{}].size() - 1; - static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1; + static constexpr auto quad_idx_offset = + transform_tuples([](auto x) { return number{}; }, reversed_outer_hs_lengthss); + + // minus 1 because RsLength is not counted + static constexpr auto quad_output_ps_minor_offset = to_sequence(generate_tuple_for( + [](auto x) { return quad_idx_offset[number{}]; }, quad_output_ps_to_rhss_major0)); + static constexpr auto quad_output_ys_minor_offset = to_sequence(generate_tuple_for( + [](auto x) { return quad_idx_offset[number{}]; }, quad_output_ys_to_rhs_major)); static constexpr auto dst_ps_to_rhss_minor = generate_tuple( [](auto i) { + constexpr auto input_i = input_ps_to_rhss_minor[i]; if constexpr(i == input_ps_to_rhss_minor.size() - 1) { - constexpr auto current_size = input_ps_to_rhss_minor[i].size(); - constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size(); - constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract( - typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{}); - return reduced_ps_to_rhss_minor.push_back(number{}); + constexpr auto outer_len = input_i.size() - quad_input_ps_to_rhss_minor0.size(); + constexpr auto outer_ps = + typename sequence_split::left_type{}; + + return outer_ps.push_back(quad_output_ps_minor_offset + + quad_output_ps_to_rhss_minor0); } else { // For all other sequences, keep them unchanged - return input_ps_to_rhss_minor[i]; + return input_i; } }, number{}); + static constexpr auto outer_input_ys_to_rhs_major = input_ys_to_rhs_major.pop_back(); + // for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed - static constexpr auto swap_one_and_two = [](const index_t idx) { - return (idx == 1) ? 2 : (idx == 2) ? 
1 : idx; - }; - static constexpr auto dst_ps_to_rhss_major = generate_tuple( - [](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); }, - number{}); + static constexpr auto dst_ys_to_rhs_major = + outer_input_ys_to_rhs_major.transform(swap_one_and_two).push_back(number<2>{}); - static constexpr auto modified_input_ys_to_rhs_major = - input_ys_to_rhs_major.pop_back().push_back(number<1>{}); + static constexpr auto dst_ys_to_rhs_minor = input_ys_to_rhs_minor.pop_back().push_back( + number<(quad_output_ys_minor_offset + quad_output_ys_to_rhs_minor)[I0]>{}); - static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2( - [](auto i) { return number{}; }, - number{}); - - static constexpr auto dst_ys_to_rhs_minor = - input_ys_to_rhs_minor.pop_back().push_back(number{}); - - using OutDstrEncode = tile_distribution_encoding, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t>; + using TransposedDstrEncode = + tile_distribution_encoding, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>; }; +template > +using OutputTileDistributionTraits = + TransposeTileDistributionTraits; +template > +using InputTileDistributionTraits = + TransposeTileDistributionTraits; + template , typename = std::enable_if_t::distr_encoding_valid, - Policy>> + typename BottomTensorView_::DataType, + Policy>::distr_encoding_valid, + Policy>> CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution& tile_window) { - using OutTileDstrEncode = - typename OutputTileDistributionTraits::OutDstrEncode; + using OutTileDstrEncode = typename OutputTileDistributionTraits< + typename TileDistribution_::DstrEncode, + typename BottomTensorView_::DataType>::TransposedDstrEncode; auto out_tensor = make_static_distributed_tensor( make_static_tile_distribution(OutTileDstrEncode{})); auto trans_tensor = tile_window.template load_transpose(); diff --git a/include/ck_tile/core/tensor/sweep_tile.hpp b/include/ck_tile/core/tensor/sweep_tile.hpp index f82f6b5bcd..6ee1fa54f4 100644 --- a/include/ck_tile/core/tensor/sweep_tile.hpp +++ b/include/ck_tile/core/tensor/sweep_tile.hpp @@ -303,6 +303,6 @@ struct tile_sweeper template ::type> -CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper; +CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper; } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index 6bcba4019c..e2a6ae6555 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -81,7 +81,7 @@ struct tensor_adaptor template CK_TILE_HOST_DEVICE static constexpr auto - get_transform_and_its_upper_dimension(number) + get_transform_and_its_upper_dimension(number) { // FIXME: length of bottom dimension is not known, since info about lower dim length are not // saved in transformation @@ -119,13 +119,13 @@ struct tensor_adaptor CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension() { - constexpr auto all_low_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, - LowerDimensionHiddenIdss{}); + constexpr auto all_low_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); - constexpr auto all_up_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, - UpperDimensionHiddenIdss{}); + constexpr auto all_up_dim_ids = + unpack([](auto&&... 
xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); @@ -461,7 +461,7 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor, sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus{}, number<0>{})); constexpr auto up_dim_hidden_idss = generate_tuple( - [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr { return typename arithmetic_sequence_gen{}); // new top dimension's hidden ids - constexpr auto unordered_new_top_dim_hidden_ids = unpack( - [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + constexpr auto unordered_new_top_dim_hidden_ids = + unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); constexpr auto new_top_dim_unordered2ordered = unpack( [](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{}); @@ -595,8 +595,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor1::get_lower_dimension_hidden_idss()[itran]; // sequence in, sequence out - constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr - { + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr { auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); // shift hidden id so every dim id is unique @@ -619,8 +618,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); return low_dim_hidden_ids_1_mod_; - } - (); + }(); return generate_sequence_v2( [&](auto i) constexpr { return number{}; }, @@ -643,8 +641,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor1::get_upper_dimension_hidden_idss()[itran]; // sequence in, constexpr tuple out - constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr - { + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr { auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); // shift hidden id @@ -653,8 +650,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); return up_dim_hidden_ids_1_mod_; - } - (); + }(); // constexpr tuple to sequence return generate_sequence_v2( diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index d7be5957c6..11e6b35c39 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -202,7 +202,7 @@ struct tile_distribution // FIXME: it's hacky to get Y index from Distributed-Index template CK_TILE_HOST_DEVICE static constexpr auto - get_y_indices_from_distributed_indices(DistributedIndices) + get_y_indices_from_distributed_indices(DistributedIndices) { constexpr auto ys_idx_arr = [] { array ys_idx; @@ -266,7 +266,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t // this returns a constexpr encoding of tile_distribution template CK_TILE_HOST_DEVICE constexpr auto - make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_) +make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_) { using RsLengths = typename StaticTileDistributionEncoding_::RsLengths; using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss; @@ -614,8 +614,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( constexpr auto src_y_maps = src_y_info[number<1>{}]; constexpr auto 
src_y_prefix_sum = src_y_info[number<2>{}]; - constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr - { + constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr { auto y_slice_sorted_origins = make_zero_multi_index(); auto y_slice_lengths = Encoding::detail::ys_lengths_; constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks(); @@ -685,8 +684,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps); return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths); - } - (); + }(); constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}]; constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}]; diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index d2b24ad54e..284efd5d70 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -327,9 +327,8 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors) template CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) { - if constexpr((std::is_same_v || - std::is_same_v)&&std::is_same_v && + if constexpr((std::is_same_v || std::is_same_v) && + std::is_same_v && (SrcTensor::get_thread_buffer_size() % 4 == 0)) { return impl::cast_tile_pk_fp8_fp32(src_tensor); diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index c4b24fba93..b5a89e5f51 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -74,8 +74,9 @@ struct tile_window_linear static constexpr auto get_num_non_linear_access() { constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths; - using ys_to_rhs_major = typename decltype( - typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = + typename decltype(typename Base::TileDstr{} + .get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto non_linear = [&]() { index_t cnt = 1; @@ -109,8 +110,9 @@ struct tile_window_linear static constexpr auto get_non_linear_access_map() { constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths; - using ys_to_rhs_major = typename decltype( - typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = + typename decltype(typename Base::TileDstr{} + .get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto non_linear_map = [&]() { array m_{0}; index_t cumulative_len_ = 1; @@ -244,8 +246,9 @@ struct tile_window_linear { using SFC_Ys = typename Base::Traits::SFC_Ys; constexpr auto idx_ys = SFC_Ys::get_index(number{}); - using ys_to_rhs_major = typename decltype( - typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = + typename decltype(typename Base::TileDstr{} + .get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto modified_idx_ys = generate_tuple( [&](auto i_dim_y) { diff --git a/include/ck_tile/core/utility/debug.hpp b/include/ck_tile/core/utility/debug.hpp new file mode 100644 index 0000000000..15f0718dc2 --- /dev/null +++ b/include/ck_tile/core/utility/debug.hpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
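+// Debug print helpers built from compile-time strings. A minimal usage
+// sketch (illustrative; `tensor` stands for any static_distributed_tensor
+// or thread_buffer):
+//
+//   CK_PRINTF{}(tensor);        // one line per lane: "tid %03d: [%02d] ..."
+//   CK_PRINTF_WARP0{}(tensor);  // same, but only threads of the first warp
+//
+// str_literal below supplies constexpr concatenation (operator+) and
+// repetition (duplicate_n), so the printf format string for an N-element
+// buffer is assembled entirely at compile time.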
+ +#pragma once +#include +#include +#include + +#include "ck_tile/core/numeric/integer.hpp" + +namespace ck_tile { +template +[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() +{ +} +template +[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() +{ +} + +template +struct str_literal +{ + static constexpr const char data[] = {Xs..., '\0'}; + static constexpr const size_t size = sizeof...(Xs); + + template + CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal /*rhs*/) const + { + return str_literal{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal sep) + { + if constexpr(N == 0) + return str_literal<>{}; + else if constexpr(N == 1) + return str_literal{}; + else + return duplicate_n(sep) + str_literal{}; + } +}; + +#define make_str_literal(lit_) \ + std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \ + makeTuple(std::make_index_sequence())) + +template +constexpr std::tuple...> +makeTuple(std::index_sequence) noexcept +{ + return {}; +} +constexpr size_t constexpr_strlen(const char* c) +{ + size_t t = 0; + while(*c++) + ++t; + return t; +} + +template +struct static_distributed_tensor; + +template +struct thread_buffer; + +// Usage example: CK_PRINTF{}(tensor); +template , + typename PREFIX = str_literal<>, + typename SUFFIX = str_literal<>> +struct CK_PRINTF; +template +struct CK_PRINTF, + str_literal, + str_literal> +{ + template + CK_TILE_HOST_DEVICE static constexpr auto default_format() + { + if constexpr(std::is_same_v) + return make_str_literal("%8.3f"); + else if constexpr(std::is_same_v) + return make_str_literal("%5d"); + else if constexpr(std::is_same_v) + return make_str_literal("%5u"); + else + return make_str_literal("0x%08x"); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_prefix() + { + constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] "); + if constexpr(sizeof...(PREFIXChars) == 0) + return fmt_tid; + else + return fmt_tid + make_str_literal(" ") + str_literal{}; + } + CK_TILE_HOST_DEVICE static constexpr auto get_suffix() + { + constexpr auto lf = make_str_literal("\n"); + if constexpr(sizeof...(SUFFIXChars) == 0) + return lf; + else + return str_literal{} + lf; + } + + template + CK_TILE_HOST_DEVICE void impl(const thread_buffer& buf, + std::integer_sequence) const + { + using FMT1 = std::conditional_t()), + str_literal>; + constexpr auto fmt_v = FMT1::template duplicate_n(make_str_literal(" ")); + constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix(); + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" + printf(fmt_wrap_v.data, get_thread_id(), N, type_convert(buf[Is])...); +#pragma clang diagnostic pop + } + + template + CK_TILE_HOST_DEVICE void operator()(const thread_buffer& buf) const + { + using ConvertTo_ = std::conditional_t, T, ConvertTo>; + impl(buf, std::make_integer_sequence{}); + } + + template + CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor& tensor) const + { + return operator()(tensor.get_thread_buffer()); + } +}; + +template , + typename PREFIX = str_literal<>, + typename SUFFIX = str_literal<>> +struct CK_PRINTF_WARP0 : public CK_PRINTF +{ + using base_t = CK_PRINTF; + + template + CK_TILE_HOST_DEVICE void operator()(const T& buf) const + { + if(get_thread_id() < get_warp_size()) + base_t::operator()(buf); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/type_traits.hpp 
b/include/ck_tile/core/utility/type_traits.hpp index 95fb1bd834..c43a64edaa 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -58,8 +58,8 @@ struct detector>, Op, Args...> struct nonesuch { - ~nonesuch() = delete; - nonesuch(nonesuch const&) = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; void operator=(nonesuch const&) = delete; }; diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp index ed3b464660..6bd6e33bd3 100644 --- a/include/ck_tile/core/utility/unary_element_function.hpp +++ b/include/ck_tile/core/utility/unary_element_function.hpp @@ -49,7 +49,7 @@ struct composes /// FIXME: create macro to replace '__host__ __device__' and nothing more template -__host__ __device__ composes(Ts&&...)->composes...>; +__host__ __device__ composes(Ts&&...) -> composes...>; template struct saturates @@ -57,8 +57,8 @@ struct saturates // NOTE: this function does not return SaturateType value // it is user's responsiblity to do further cast or not template - CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const - -> std::enable_if_t, AccType> + CK_TILE_HOST_DEVICE constexpr auto + operator()(const AccType& a_) const -> std::enable_if_t, AccType> { return clamp(a_, type_convert(numeric::lowest()), diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 4a9748fcbb..aa5afd25e5 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -27,6 +27,7 @@ #include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp" #include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" @@ -37,6 +38,7 @@ #include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/reference/reference_topk.hpp" +#include "ck_tile/host/reference/reference_transpose.hpp" #include "ck_tile/host/rotating_buffers.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_utils.hpp" diff --git a/include/ck_tile/host/concat.hpp b/include/ck_tile/host/concat.hpp index c68b908149..e9ba9a7d7b 100644 --- a/include/ck_tile/host/concat.hpp +++ b/include/ck_tile/host/concat.hpp @@ -33,13 +33,14 @@ struct IsCharArray : std::true_type }; template -inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v || - IsCharArray::value || - std::is_same_v)&&...); +inline constexpr bool AllConvertibleToStringView = + ((std::is_convertible_v || IsCharArray::value || + std::is_same_v) && + ...); template -[[nodiscard]] auto concat(const Ts&... xs) - -> std::enable_if_t, std::string> +[[nodiscard]] auto +concat(const Ts&... xs) -> std::enable_if_t, std::string> { using ::operator<<; thread_local std::ostringstream oss; @@ -78,8 +79,8 @@ template } template -auto concatInto(std::string& result, const Ts&... xs) - -> std::enable_if_t, void> +auto concatInto(std::string& result, + const Ts&... xs) -> std::enable_if_t, void> { const std::size_t space = (1 + ... + getSize(xs)); result.reserve(result.size() + space); @@ -87,8 +88,8 @@ auto concatInto(std::string& result, const Ts&... xs) } template -[[nodiscard]] auto concat(const Ts&... 
xs) - -> std::enable_if_t, std::string> +[[nodiscard]] auto +concat(const Ts&... xs) -> std::enable_if_t, std::string> { std::string result; concatInto(result, xs...); diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 4a359e031f..e03881a1c7 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -63,7 +64,7 @@ struct FillUniformDistribution return; // need to make each thread unique, add an offset to current seed std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin) - : std::random_device{}()); + : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); @@ -92,6 +93,60 @@ struct FillUniformDistribution } }; +template <> +struct FillUniformDistribution +{ + float a_{-8.f}; // same type as primary template so that + // `FillUniformDistribution{-5.0f, 5.0f}` works for all types + float b_{7.f}; + std::optional seed_{11939}; + template + void operator()(ForwardIter first, ForwardIter last) const + { + if(a_ < -8.0f || b_ > 7.0f) + { + throw std::runtime_error( + "a_ or b_ of FillUniformDistribution is out of range."); + } + + int min_value = static_cast(a_); + int max_value = static_cast(b_); + constexpr auto int4_array = std::array{0x88, + 0x99, + 0xaa, + 0xbb, + 0xcc, + 0xdd, + 0xee, + 0xff, + 0x00, + 0x11, + 0x22, + 0x33, + 0x44, + 0x55, + 0x66, + 0x77}; + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::uniform_int_distribution dis(0, max_value - min_value + 1); + while(first != last) + { + int randomInt = dis(gen); + *first = int4_array[randomInt + (min_value + 8)]; + ++first; + } + } + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + namespace impl { // clang-format off @@ -187,7 +242,7 @@ struct FillNormalDistribution return; // need to make each thread unique, add an offset to current seed std::mt19937 gen(seed_.has_value() ? 
(*seed_ + iw_begin) - : std::random_device{}()); + : std::random_device{}()); std::normal_distribution dis(mean_, std::sqrt(variance_)); std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); @@ -352,9 +407,10 @@ struct FillStepRange } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); @@ -373,9 +429,10 @@ struct FillConstant } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); @@ -457,9 +514,10 @@ struct FillTrigValue } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index ecbc009b85..c3f1b7d221 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -378,7 +378,7 @@ struct HostTensor ~HostTensor() = default; HostTensor& operator=(const HostTensor&) = default; - HostTensor& operator=(HostTensor&&) = default; + HostTensor& operator=(HostTensor&&) = default; template explicit HostTensor(const HostTensor& other) : HostTensor(other.template CopyAsType()) diff --git a/include/ck_tile/host/joinable_thread.hpp b/include/ck_tile/host/joinable_thread.hpp index a822f967dc..a42b567fb4 100644 --- a/include/ck_tile/host/joinable_thread.hpp +++ b/include/ck_tile/host/joinable_thread.hpp @@ -15,7 +15,7 @@ struct joinable_thread : std::thread { } - joinable_thread(joinable_thread&&) = default; + joinable_thread(joinable_thread&&) = default; joinable_thread& operator=(joinable_thread&&) = default; ~joinable_thread() diff --git a/include/ck_tile/host/reference/reference_elementwise.hpp b/include/ck_tile/host/reference/reference_elementwise.hpp index 65303279b8..3e174bf870 100644 --- a/include/ck_tile/host/reference/reference_elementwise.hpp +++ b/include/ck_tile/host/reference/reference_elementwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
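+// Note on the int4 FillUniformDistribution specialization above: every byte
+// of int4_array packs the same signed 4-bit value into both nibbles, and a
+// drawn value v in [-8, 7] selects index v + 8. Worked examples:
+//
+//   v = -8  ->  int4_array[0]  = 0x88   // both nibbles encode -8
+//   v =  0  ->  int4_array[8]  = 0x00   // both nibbles encode  0
+//   v =  3  ->  int4_array[11] = 0x33   // both nibbles encode +3
+//
+// so a buffer filled this way decodes to uniformly distributed int4 pairs.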
#pragma once diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index c88deaec01..70ca44170e 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -11,6 +11,110 @@ namespace ck_tile { +template +CK_TILE_HOST void reference_gemm_quant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mn = [&](auto m, auto n) { + AccDataType v_acc = 0, v_block_acc = 0; + + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v || + std::is_same_v); + for(std::size_t k = 0; k < K; ++k) + { + AccDataType v_a; + AccDataType v_b; + if constexpr(std::is_same_v) + { + const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val); + if(k % 2 == 1) + v_a = fp32_val.hi; + else + v_a = fp32_val.lo; + } + else + { + v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); + } + if constexpr(std::is_same_v) + { + const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val); + if(k % 2 == 1) + v_b = fp32_val.hi; + else + v_b = fp32_val.lo; + } + else if constexpr(std::is_same_v) + { + v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n))); + } + else + { + v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); + } + v_block_acc += v_a * v_b; + + // Apply group dequant scale + if((k + 1) % QuantGroupSize == 0) + { + float scale = 0.f; + index_t outer_dim = (aquant) ? m : k / QuantGroupSize; + index_t inner_dim = (aquant) ? k / QuantGroupSize : n; + + if constexpr(std::is_same_v) + { + scale = q(outer_dim, inner_dim); + } + else if constexpr(std::is_same_v) + { + scale = fp8_to_float_raw(q(outer_dim, inner_dim)); + } + else if constexpr(std::is_same_v) + { + scale = bf8_to_float_raw(q(outer_dim, inner_dim)); + } + else + { + static_assert(false, "Unexpected Q datatype."); + } + v_block_acc *= scale; + v_acc += v_block_acc; + v_block_acc = 0; + } + } + + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + }; + + make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); + std::cout << std::endl; +} + template +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST void +reference_grouped_conv_bwd_weight(const HostTensor& input, + HostTensor& weight, + const HostTensor& output, + std::vector conv_strides, + std::vector conv_dilations, + std::vector in_left_pads, + std::vector) +{ + if(!(input.get_num_of_dimension() == NDimSpatial + 3 && + weight.get_num_of_dimension() == NDimSpatial + 3 && + output.get_num_of_dimension() == NDimSpatial + 3)) + { + throw std::runtime_error("wrong! 
inconsistent dimension"); + } + + if constexpr(NDimSpatial == 1) + { + auto func = [&](auto g, auto k, auto c, auto x) { + float v_acc = 0; + + for(std::size_t n = 0; n < output.get_lengths()[1]; ++n) + { + for(std::size_t wo = 0; wo < output.get_lengths()[3]; ++wo) + { + auto wi = static_cast(wo * conv_strides[0]) + + static_cast(x * conv_dilations[0]) - + static_cast(in_left_pads[0]); + + if(wi >= 0 && ck_tile::type_convert(wi) < input.get_lengths()[3]) + { + InDataType v_in = input(g, n, c, wi); + OutDataType v_out = output(g, n, k, wo); + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_in); + } + } + } + OutDataType v_acc_converted = ck_tile::type_convert(v_acc); + weight(g, k, c, x) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + weight.get_lengths()[0], + weight.get_lengths()[1], + weight.get_lengths()[2], + weight.get_lengths()[3])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 2) + { + auto func = [&](auto g, auto k, auto c, auto y, auto x) { + float v_acc = 0; + + for(std::size_t n = 0; n < output.get_lengths()[1]; ++n) + { + for(std::size_t ho = 0; ho < output.get_lengths()[3]; ++ho) + { + auto hi = static_cast(ho * conv_strides[0]) + + static_cast(y * conv_dilations[0]) - + static_cast(in_left_pads[0]); + + for(std::size_t wo = 0; wo < output.get_lengths()[4]; ++wo) + { + auto wi = static_cast(wo * conv_strides[1]) + + static_cast(x * conv_dilations[1]) - + static_cast(in_left_pads[1]); + + if(hi >= 0 && + ck_tile::type_convert(hi) < input.get_lengths()[3] && + wi >= 0 && + ck_tile::type_convert(wi) < input.get_lengths()[4]) + { + InDataType v_in = input(g, n, c, hi, wi); + OutDataType v_out = output(g, n, k, ho, wo); + + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_in); + } + } + } + } + WeiDataType v_acc_converted = ck_tile::type_convert(v_acc); + weight(g, k, c, y, x) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + weight.get_lengths()[0], + weight.get_lengths()[1], + weight.get_lengths()[2], + weight.get_lengths()[3], + weight.get_lengths()[4])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 3) + { + auto func = [&](auto g, auto k, auto c, auto z, auto y, auto x) { + float v_acc = 0; + + for(std::size_t n = 0; n < output.get_lengths()[1]; ++n) + { + for(std::size_t do_ = 0; do_ < output.get_lengths()[3]; ++do_) + { + auto di = static_cast(do_ * conv_strides[0]) + + static_cast(z * conv_dilations[0]) - + static_cast(in_left_pads[0]); + for(std::size_t ho = 0; ho < output.get_lengths()[4]; ++ho) + { + auto hi = static_cast(ho * conv_strides[1]) + + static_cast(y * conv_dilations[1]) - + static_cast(in_left_pads[1]); + for(std::size_t wo = 0; wo < output.get_lengths()[5]; ++wo) + { + auto wi = static_cast(wo * conv_strides[2]) + + static_cast(x * conv_dilations[2]) - + static_cast(in_left_pads[2]); + if(di >= 0 && + ck_tile::type_convert(di) < input.get_lengths()[3] && + hi >= 0 && + ck_tile::type_convert(hi) < input.get_lengths()[4] && + wi >= 0 && + ck_tile::type_convert(wi) < input.get_lengths()[5]) + { + InDataType v_in = input(g, n, c, di, hi, wi); + OutDataType v_out = output(g, n, k, do_, ho, wo); + + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_in); + } + } + } + } + } + WeiDataType v_acc_converted = ck_tile::type_convert(v_acc); + weight(g, k, c, z, y, x) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + weight.get_lengths()[0], + weight.get_lengths()[1], + weight.get_lengths()[2], + weight.get_lengths()[3], + 
weight.get_lengths()[4], + weight.get_lengths()[5])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error( + "Ref_conv_bwd_weight: number of dimensions must be between 1 and 3."); + } +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp index 1e877b9933..b7615d0478 100644 --- a/include/ck_tile/host/reference/reference_moe_sorting.hpp +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -9,7 +9,7 @@ namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ - static_cast(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24)) + static_cast(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24)) template CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, diff --git a/include/ck_tile/host/reference/reference_transpose.hpp b/include/ck_tile/host/reference/reference_transpose.hpp new file mode 100644 index 0000000000..45d3dc9efa --- /dev/null +++ b/include/ck_tile/host/reference/reference_transpose.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +void reference_transpose_elementwise(const HostTensor& a, HostTensor& b) +{ + ck_tile::index_t M = static_cast(a.mDesc.get_lengths()[0]); + ck_tile::index_t N = static_cast(a.mDesc.get_lengths()[1]); + + // Ensure the b tensor is sized correctly for N x M + if(static_cast(b.mDesc.get_lengths()[0]) != N || + static_cast(b.mDesc.get_lengths()[1]) != M) + { + throw std::runtime_error("Output tensor b has incorrect dimensions for transpose."); + } + + auto f = [&](auto i, auto j) { + auto v_a = a(i, j); + b(j, i) = ck_tile::type_convert(v_a); + }; + + make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency()); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index 200e2a618c..ca0088c812 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -4,6 +4,10 @@ #pragma once #include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp index 4c3aa2ba29..a89a190489 100644 --- a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -32,7 +32,7 @@ struct BatchedTransposeKernel using Pipeline = remove_cvref_t; using Problem = remove_cvref_t; - using Type = typename Problem::InputType; + using Type = typename Problem::DataType; struct BatchedTransposeKargs { @@ -67,7 +67,7 @@ struct BatchedTransposeKernel 
return k; } - CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::kBlockSize; } CK_TILE_DEVICE void operator()(Kargs kargs) const { diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp new file mode 100644 index 0000000000..e344c24bf5 --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +struct BatchedTransposeCommonPolicy +{ + CK_TILE_DEVICE static constexpr auto TileAccessPattern = + tile_distribution_pattern::thread_raked; + + template + CK_TILE_DEVICE static constexpr auto MakeInputDistribution() + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t LeadDimPerBlock = Problem::kMPerBlock; + constexpr index_t SecondDimPerBlock = Problem::kNPerBlock; + + constexpr index_t kVectorSize = Problem::VectorSizeOutput; + + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp new file mode 100644 index 0000000000..ef0b7fa229 --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
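+// LDS-based transpose pipeline. The dataflow implemented below, in sketch
+// form:
+//
+//   auto x = load_tile(input_tile_window);              // global -> registers
+//   store_tile(copy_to_lds_window, x);                  // registers -> LDS
+//   block_sync_lds();                                   // make LDS writes visible
+//   auto y = load_tile_transpose(load_from_lds_window); // LDS -> registers,
+//                                                       // transposed
+//   store_tile(output_tile_window, y);                  // registers -> global
+//
+// Both LDS windows alias the same shared-memory buffer; only their
+// descriptors differ, so the transpose is realized purely by the layout
+// change between the store-side and load-side views.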
+ +#pragma once + +namespace ck_tile { + +template +struct BatchedTransposeLdsPipeline +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using DataType = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock; + static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock; + + static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); } + + CK_TILE_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE void operator()(const InputTileWindow& input_window, + OutputTileWindow& output_window) + { + __shared__ char smem[GetSmemSize()]; + auto input_tile_window = + make_tile_window(input_window, Policy::template MakeInputDistribution()); + auto output_tile_window = + make_tile_window(output_window, Policy::template MakeOutputDistribution()); + + DataType* p_lds_ptr = reinterpret_cast(smem); + constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor(); + auto input_lds_block = + make_tensor_view(p_lds_ptr, in_lds_block_desc); + + constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor(); + auto output_lds_block = + make_tensor_view(p_lds_ptr, out_lds_block_desc); + + auto copy_to_lds_window = + make_tile_window(input_lds_block, + make_tuple(number{}, number{}), + {0, 0}); + auto load_from_lds_window = + make_tile_window(output_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeLdsLoadTileDistribution()); + + auto x = load_tile(input_tile_window); + + store_tile(copy_to_lds_window, x); + block_sync_lds(); + + auto y = load_tile_transpose(load_from_lds_window); + + store_tile(output_tile_window, y); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp similarity index 63% rename from example/ck_tile/37_transpose/transpose_policy.hpp rename to include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp index ea1a4130fe..77c3db9c06 100644 --- a/example/ck_tile/37_transpose/transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp @@ -1,24 +1,17 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
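+// Note on the lane-group geometry consumed by this policy: the transposed
+// LDS read moves DS_READ_TR_SIZE() == 8 bytes per lane, i.e.
+// 8 / sizeof(DataType) elements per read. Illustrative arithmetic (assuming
+// the Quad16/Quad8 encodings from load_tile_transpose.hpp are keyed to
+// 2-byte and 1-byte element types, respectively):
+//
+//   fp16 / bf16 (2 bytes): 8 / 2 = 4 elements per lane -> Quad16 pattern
+//   int8 / fp8  (1 byte) : 8 / 1 = 8 elements per lane -> Quad8 pattern
+//
+// which is why the quad encoding is selected on sizeof(DataType) rather
+// than on the concrete type.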
#pragma once #include "ck_tile/core.hpp" +#include "batched_transpose_common_policy.hpp" namespace ck_tile { -struct TransposePolicy +struct BatchedTransposeLdsPolicy : public BatchedTransposeCommonPolicy { - static constexpr auto TileAccessPattern = tile_distribution_pattern::thread_raked; - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSize() - { - return 16 / sizeof(typename Problem::DataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_DEVICE static constexpr index_t GetSmemSize() { return integer_least_multiple( sizeof(typename Problem::DataType) * @@ -27,40 +20,24 @@ struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock; - constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock; - constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType); - - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution() + CK_TILE_DEVICE static constexpr auto MakeOutputDistribution() { constexpr auto input_dstr = MakeLdsLoadTileDistribution(); using OutTileDstrEncode = - typename OutputTileDistributionTraits, - typename Problem::DataType>::OutDstrEncode; + typename OutputTileDistributionTraits::TransposedDstrEncode; constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{}); return block_dstr; } template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor() { constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock; constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock; - constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType); + constexpr index_t kVectorSize = Problem::LDSVectorSize; constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, @@ -82,12 +59,11 @@ struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor() { constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock; constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock; - - constexpr index_t kVectorSize = 8 / sizeof(typename Problem::DataType); + constexpr index_t kVectorSize = Problem::LDSVectorSize; constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, @@ -109,29 +85,25 @@ struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeLdsLoadTileDistribution() { using DataType = typename Problem::DataType; - // Extract base dimensions from the traits - constexpr index_t kBaseLeadDim = LaneGroupTransposeTraits::kleadDim; - constexpr index_t kBaseSecondDim = LaneGroupTransposeTraits::ksecondDim; - // Calculate block-level dimensions - constexpr index_t kLead = Problem::kLeadSizePerXdl; - constexpr index_t kSecond = Problem::kSecondSizePerXdl; - constexpr index_t kLeadIterPerWarp = Problem::kLeadXdlNumPerWarp; - constexpr index_t kSecondIterPerWarp = Problem::kSecondXdlNumPerWarp; + constexpr index_t kLeadIterPerWarp = 1; + constexpr index_t kSecondIterPerWarp = 1; constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps; 
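+        // Worked example (all numbers illustrative): with a 64x64 warp tile
+        // and a 16x16 lane-group quadrant, kQuadNumPerLeadDim = 64 / 16 = 4
+        // and kQuadNumPerSecondDim = 64 / 16 = 4, so the base quadrant
+        // pattern is repeated 4 x 4 across the warp footprint before warp-
+        // and block-level dimensions are layered on top.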
constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps; // Calculate repetitions of base pattern - constexpr index_t kLeadRepetitions = kLead / kBaseLeadDim; - constexpr index_t kSecondRepetitions = kSecond / kBaseSecondDim; + constexpr index_t kLeadRepetitions = Problem::kQuadNumPerLeadDim; + constexpr index_t kSecondRepetitions = Problem::kQuadNumPerSecondDim; constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim; constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations; + constexpr index_t kLaneGroupSize = 16; constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode + typename NumWarps, + bool kPadM_, + bool kPadN_> +struct BatchedTransposeLdsProblem +{ + using DataType = remove_cvref_t; + + static constexpr index_t kRowWarps_ = NumWarps::at(number<1>{}); + static constexpr index_t kColWarps_ = NumWarps::at(number<0>{}); + static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_; + static constexpr index_t kRowPerBlock_ = BlockTile::at(number<1>{}); + static constexpr index_t kColPerBlock_ = BlockTile::at(number<0>{}); + + static constexpr index_t kBlockSize = kBlockSize_; + // warps per block + static constexpr index_t kLeadNumWarps = kRowWarps_; + static constexpr index_t kSecondNumWarps = kColWarps_; + + static constexpr index_t kLeadSizePerBlock = kRowPerBlock_; + static constexpr index_t kSecondSizePerBlock = kColPerBlock_; + + static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits::kleadDim; + static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits::ksecondDim; + + static_assert(kLeadSizePerBlock % kLeadNumWarps == 0, + "block dim should be divided by warp count!"); + static_assert(kSecondSizePerBlock % kSecondNumWarps == 0, + "block dim should be divided by warp count!"); + // rows/cols per warp + static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps; + static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps; + + static_assert(kLeadSizePerWarp % kQuadrantLeadDim == 0, + "xdl dim should be divided by quad dim!"); + static_assert(kSecondSizePerWarp % kQuadrantSecondDim == 0, + "xdl dim should be divided by quad dim!"); + // xdl rows/cols is divided into quadrants. + static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerWarp / kQuadrantLeadDim; + static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerWarp / kQuadrantSecondDim; + + static constexpr index_t kIterationsInSecondDim = + kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size(); + + // definitions to adapt to BatchedTransposeKernel + + // FIXME: support padding + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + + static constexpr auto kMPerBlock = kLeadSizePerBlock; + static constexpr auto kNPerBlock = kSecondSizePerBlock; + + // 128-bit is the max single-instruction bandwidth for load/store + static constexpr index_t MaxLoadStoreSize = 16; + static constexpr auto VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType); + static constexpr auto VectorSizeOutput = kPadM ? 
1 : MaxLoadStoreSize / sizeof(DataType); + static constexpr auto LDSVectorSize = MaxLoadStoreSize / sizeof(DataType); +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp index e815313c06..633827f3c3 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp @@ -5,8 +5,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" -#include -#include namespace ck_tile { @@ -14,15 +12,8 @@ template struct BatchedTransposePipeline { // TODO: this kernel only support warp per row - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using InputType = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; - static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t AlignmentM = Problem::AlignmentM; - static constexpr index_t AlignmentN = Problem::AlignmentN; - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; template CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window) @@ -32,7 +23,7 @@ struct BatchedTransposePipeline auto input_tile = load_tile(inp_win); - auto output_tile = make_static_distributed_tensor( + auto output_tile = make_static_distributed_tensor( Policy::template MakeOutputDistribution()); transpose_tile2d(output_tile, input_tile); diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp index dd9a6d79a8..5238fecdc5 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp @@ -4,43 +4,25 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/softmax.hpp" -#include "ck_tile/ops/topk.hpp" +#include "batched_transpose_common_policy.hpp" namespace ck_tile { -struct BatchedTransposePolicy +struct BatchedTransposePolicy : public BatchedTransposeCommonPolicy { template - CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::kMPerBlock; - constexpr index_t NPerBlock = Problem::kNPerBlock; - constexpr index_t VecLoadSize = Problem::VectorSizeInput; - using TileEncodingPattern = - TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution() + CK_TILE_DEVICE static constexpr auto MakeOutputDistribution() { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::kMPerBlock; constexpr index_t NPerBlock = Problem::kNPerBlock; constexpr index_t VecLoadSize = Problem::VectorSizeOutput; - using TileEncodingPattern = - TileDistributionEncodingPattern2D; + using TileEncodingPattern = TileDistributionEncodingPattern2D; return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } }; diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp 
b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp index fd5ea004b6..2be979723b 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp @@ -6,42 +6,31 @@ #include "ck_tile/core.hpp" #include -#define VectorLoadSize 16 - namespace ck_tile { -template // Sequence<... struct BatchedTransposeProblem { - using InputType = remove_cvref_t; + using DataType = remove_cvref_t; - static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); - static constexpr index_t kNPerThread = ThreadTile::at(number<1>{}); - - static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); - static constexpr index_t kNPerWarp = WarpTile::at(number<1>{}); - - static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; - static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread; + static constexpr index_t kMPerWarp = WarpLayout::at(number<0>{}); + static constexpr index_t kNPerWarp = WarpLayout::at(number<1>{}); static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); static constexpr index_t kNPerBlock = BlockTile::at(number<1>{}); - static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp; - static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp; - - static constexpr index_t kBlockSize = - kMThreadPerWarp * kNThreadPerWarp * kMWarpPerBlock * kNWarpPerBlock; + static constexpr index_t kBlockSize = kMPerWarp * kNPerWarp * get_warp_size(); static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; - static constexpr index_t VectorSizeInput = kPadM ? 1 : VectorLoadSize / sizeof(InputType); - static constexpr index_t VectorSizeOutput = kPadN ? 1 : VectorLoadSize / sizeof(InputType); + // 128-bit is the max single-instruction bandwidth for load/store + static constexpr index_t MaxLoadStoreSize = 16; + static constexpr index_t VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType); + static constexpr index_t VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType); }; } // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 53187771b9..4858245ec4 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -3,6 +3,11 @@ #pragma once +#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp" +#include "ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp b/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp new file mode 100644 index 0000000000..f9b1cf3352 --- /dev/null +++ b/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
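+// Binary element-wise functors with explicit mixed-precision
+// specializations. A minimal usage sketch (host or device; values
+// illustrative):
+//
+//   ck_tile::element_wise::Add add{};
+//   float yf;
+//   add(yf, 1.0f, 2.0f);  // float = float + float
+//   ck_tile::half_t yh;
+//   add(yh, 1.0f, ck_tile::type_convert<ck_tile::half_t>(0.5f));
+//   // -> half_t = float + half_t
+//
+// The low-precision specializations below convert through float where
+// needed, so bf16/half operands are added without implicit narrowing.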
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +namespace element_wise { + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = x0 + type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 + x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const bf16_t& x1) const + { + const float x1_tmp = type_convert(x1); + y = x0 + x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(bf16_t& y, const bf16_t& x0, const bf16_t& x1) const + { + const float x1_tmp = type_convert(x0); + const float x2_tmp = type_convert(x1); + const float y_tmp = x1_tmp + x2_tmp; + y = type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bf16_t& y, const float& x0, const bf16_t& x1) const + { + const float x2_tmp = type_convert(x1); + const float y_tmp = x0 + x2_tmp; + y = type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 + x1; + }; +}; + +} // namespace element_wise +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp new file mode 100644 index 0000000000..103468c5fa --- /dev/null +++ b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp" +namespace ck_tile { + +template +struct ElementWiseKernel +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using ElementWiseOperation = ck_tile::remove_cvref_t; + + template + CK_TILE_DEVICE void operator()(Dims lens, + Dims input_strides, + Dims output_strides, + const tuple& input_tensors, + YDataType* p_y) const + { + using S = typename Problem::BlockShape; + + // Setup block-level coordinates and transforms + const index_t iM = get_block_id() * S::kBlockM; + const auto merge_transform = make_merge_transform(lens); + + // Load all input tiles into registers. + // The lambda structure here is intended to minimize the lifetime + // of intermediate objects (views, windows) used for loading. 
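// As a sketch of the expansion below (illustrative, for NumInputTensors == 2;
// not literal generated code):
//   x_tiles = make_tuple(load_tile(window for input 0),
//                        load_tile(window for input 1));
// Each per-input chain -- naive view, merge to 1D, pad to a multiple of
// S::kBlockM, tile window at {iM} -- is a temporary that dies when its lambda
// invocation returns, so only the loaded register tiles remain live.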
+ const auto x_tiles = ck_tile::generate_tuple( + [&](auto i) { + const auto tensor_view = make_naive_tensor_view( + input_tensors.get(i), lens, input_strides, number{}, number<1>{}); + + const auto transformed_tensor = pad_tensor_view( + transform_tensor_view(tensor_view, + ck_tile::make_tuple(merge_transform), + ck_tile::make_tuple(make_index_sequence{}), + ck_tile::make_tuple(sequence<0>{})), + ck_tile::make_tuple(number{}), + sequence{}); + + const auto x_window = + make_tile_window(transformed_tensor, + ck_tile::make_tuple(number{}), + {iM}, + Policy::template MakeXBlockTileDistribution()); + + return load_tile(x_window); + }, + number{}); + + // Setup output tile in registers. + const auto& x_tile0 = x_tiles.get(number<0>{}); + auto y_tile = make_static_distributed_tensor(x_tile0.get_tile_distribution()); + + // Perform element-wise computation. + const auto spans = x_tile0.get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx) { + const auto tile_idx = make_tuple(idx); + apply( + [&](auto&&... tiles) { + ElementWiseOperation{}(y_tile(tile_idx), + type_convert(tiles[tile_idx])...); + }, + x_tiles); + }); + + // Setup output window and store the result tile. + const auto y_m_n = make_naive_tensor_view( + p_y, lens, output_strides, number{}); + + const auto transformed_y_m_n = pad_tensor_view( + transform_tensor_view(y_m_n, + ck_tile::make_tuple(merge_transform), + ck_tile::make_tuple(make_index_sequence{}), + ck_tile::make_tuple(sequence<0>{})), + ck_tile::make_tuple(number{}), + sequence{}); + + auto y_window = make_tile_window(transformed_y_m_n, + make_tuple(number{}), + {iM}, + y_tile.get_tile_distribution()); + + store_tile(y_window, cast_tile(y_tile)); + } + + template + CK_TILE_HOST static bool IsSupportedArgument(const ck_tile::tuple& input_sizes) + { + int total_elements = 1; + const auto kVectorM = Problem_::BlockShape::kVectorM; + + apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes); + + if((total_elements % kVectorM) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met: total number of input elements (", + total_elements, + ") should be multiple of the vectorization size (", + kVectorM, + ")"); + } + return false; + } + + return true; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp new file mode 100644 index 0000000000..9cba43d350 --- /dev/null +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
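A host-side sketch of the launch guard defined in ElementWiseKernel::IsSupportedArgument above, assuming ComputeDataType = float so that BlockShape::kVectorM == 16 / sizeof(float) == 4 (the sizes are illustrative):

    ck_tile::tuple<ck_tile::index_t> ok_sizes{1024};  // 1024 % 4 == 0 -> accepted
    ck_tile::tuple<ck_tile::index_t> bad_sizes{1023}; // 1023 % 4 != 0 -> rejected;
    // with CK_TILE_LOGGING enabled, CK_TILE_ERROR reports the offending counts.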
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +struct ElementWiseDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + return make_static_tile_distribution( + tile_distribution_encoding, // Replicate + tuple>, // Hierarchical + tuple, sequence<1>>, // Parallel + tuple, sequence<2>>, // Parallel + sequence<1, 1>, // Yield + sequence<0, 3>>{} // Yield + ); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp new file mode 100644 index 0000000000..a5d00ee1d0 --- /dev/null +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct ElementWisePipelineProblem +{ + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + using ElementWiseOperation = remove_cvref_t; + static constexpr bool kPad = kPad_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp new file mode 100644 index 0000000000..0d25a8a202 --- /dev/null +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct ElementWiseShape +{ + static constexpr index_t kBlockM = BlockTile::at(number<0>{}); + + static constexpr index_t kWarpM = WarpTile::at(number<0>{}); + + static constexpr index_t kVectorM = 16 / sizeof(ComputeDataType); + + static constexpr index_t kWarpPerBlockM = BlockWarps::at(number<0>{}); + + static constexpr index_t kThreadPerWarpM = kWarpM / kVectorM; + + static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kWarpM); + + static constexpr index_t kBlockSize = + ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index a3fe5045cf..0e385901ed 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
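For a concrete feel of the shape arithmetic in ElementWiseShape above, take an illustrative (not shipped) configuration with ComputeDataType = half_t, BlockWarps = sequence<4>, BlockTile = sequence<4096> and WarpTile = sequence<512>:

    // kVectorM        = 16 / sizeof(half_t) = 8 elements per dwordx4 access
    // kThreadPerWarpM = 512 / 8             = 64 lanes, one full wave
    // kRepeatM        = 4096 / (4 * 512)    = 2 repeats per thread
    // kBlockSize      = 64 * 4              = 256 threads per workgroup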
#pragma once @@ -110,6 +110,86 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) return res; } +CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) +{ + uint32_t src = static_cast(a), src_hi; + uint32_t fp8x4_lo, fp8x4_hi; + float tmp_0, tmp_1; + + asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n" + "v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n" + "v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n" + "v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n" + "v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n" + : [v_tmp_0] "+v"(tmp_0), + [v_tmp_1] "+v"(tmp_1), + [v_hi_src] "+v"(src_hi), + [v_dst_lo] "+v"(fp8x4_lo), + [v_dst_hi] "+v"(fp8x4_hi), + [v_src] "+v"(src) + :); + + return bit_cast(((static_cast(fp8x4_hi) << 32) | fp8x4_lo)); +} + +CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src) +{ + float res; + asm volatile("v_cvt_f32_fp8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src)); + return res; +} + +CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src) +{ + float res; + asm volatile("v_cvt_f32_bf8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src)); + return res; +} + +CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a) +{ + uint32_t src = static_cast(a), src_hi; + uint32_t bf8x4_lo, bf8x4_hi; + float tmp_0, tmp_1; + + asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n" + "v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n" + "v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n" + "v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n" + "v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n" + : [v_tmp_0] "+v"(tmp_0), + [v_tmp_1] "+v"(tmp_1), + [v_hi_src] "+v"(src_hi), + [v_dst_lo] "+v"(bf8x4_lo), + [v_dst_hi] "+v"(bf8x4_hi), + [v_src] "+v"(src) + :); + + return bit_cast(((static_cast(bf8x4_hi) << 32) | bf8x4_lo)); +} + struct PassThroughPack8 { template @@ -126,6 +206,16 @@ struct PassThroughPack8 y.lo = i4_to_bhalf4(bit_cast(x)); y.hi = i4_to_bhalf4(bit_cast(x) >> 16); } + + CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const + { + y = amd_assembly_i4_to_fp8x8(bit_cast(x)); + } + + CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const + { + y = amd_assembly_i4_to_bf8x8(bit_cast(x)); + } constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index bf58544259..d42f144baa 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ 
b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -69,6 +69,8 @@ struct CShuffleEpilogue using ODataType = remove_cvref_t; using DsDataType = remove_cvref_t; using DsLayout = remove_cvref_t; + using ATypeToUse = + std::conditional_t, BDataType, ADataType>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; @@ -201,7 +203,7 @@ struct CShuffleEpilogue static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); - using WG = WarpGemmMfmaDispatcher, - sequence<0, 1>, - sequence>; + sequence<0, 1>, + sequence>; constexpr index_t num_access = SFC::get_num_of_access(); static_assert(std::is_same_v, @@ -334,8 +336,8 @@ struct CShuffleEpilogue const auto c_ds_tiles = concat_tuple_of_reference( tie(c_out_tensor, c_out_tensor), - generate_tie( - [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); + generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; }, + number{})); tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); diff --git a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp index 1dcd62011a..23c4ad583e 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -73,7 +73,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 sequence<2, 1>, // !! note here is different sequence<0, 0>>{}; - using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; + using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); diff --git a/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp index 0b812875c4..037bb7688c 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -49,7 +49,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base sequence<2, 1>, // !! 
note here is different sequence<0, 0>>{}; - using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; + using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index edb5853c7f..54f2a777bf 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -458,7 +458,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_flat_dram_block_window_tmp, num_loop, p_smem); diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 837aeb13e3..cc00000efc 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -431,12 +431,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::BDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< typename Problem::ADataType, diff --git a/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp b/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp index 551d390ec6..0e98078d53 100644 --- a/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp @@ -29,6 +29,9 @@ struct TileFlatmmShape static constexpr index_t flatKPerWarp = WarpTile::at(idxK) * WarpTile::at(idxN); static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(idxK); + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + CK_TILE_HOST static std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index f21136d2a8..30bea193b7 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -15,9 +15,9 @@ #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" @@ -29,14 +29,14 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include 
"ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 561e5fb00a..8d257a3329 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -955,9 +955,9 @@ struct FmhaFwdKernel else { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + return dim3(nhead_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - nhead_, batch_size_); } } @@ -1003,8 +1003,8 @@ struct FmhaFwdKernel const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; + const index_t i_block = blockIdx.y; // blockIdx.x + const index_t i_nhead = blockIdx.x; // blockIdx.y const index_t i_batch = blockIdx.z; const auto f = [](index_t dividend, index_t divisor) { @@ -1018,7 +1018,7 @@ struct FmhaFwdKernel if constexpr(kHasMask) { // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 10daea99d1..6398bf316e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -122,9 +122,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; - // use larger K/V LDS buffer size will lower the occupancy - else if constexpr(64 <= kK0 || 64 <= kK1) - return 1; else return 2; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index d1b6e6f85b..420ae03b7e 100644 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -738,6 +738,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); block_sync_lds(); @@ -976,6 +981,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP decltype(ds_gemm)>(dst_reg_tensor, ds_gemm); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 0b8e5836cd..3489d6f9a1 100--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -509,7 +509,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto - MakeKLdsStoreBlockDescriptor(number = number<0>{}) + MakeKLdsStoreBlockDescriptor(number = number<0>{}) { // K is always k-major, we use async-copy to load into LDS constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 76ba34115f..570cff8bf0 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -60,8 +60,8 @@ struct TileFmhaShape // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; using VLayout = std::conditional_t; + ck_tile::tensor_layout::gemm::RowMajor, + ck_tile::tensor_layout::gemm::ColumnMajor>; }; template (kargs.o_ptr); auto o_view_ = make_naive_tensor_view( + memory_operation_enum::atomic_add>( o_ptr, make_tuple(kargs.num_tokens, kargs.hidden_size), make_tuple(kargs.stride_token, 1), diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 5da675ae42..a5f9f31d6a 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -13,7 +13,7 @@ namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ - static_cast(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24)) + static_cast(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24)) #ifndef MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_USE_EX_KERNEL 1 @@ -23,6 +23,11 @@ namespace ck_tile { #define MOE_SORTING_FUSE_MP_01 0 #endif +// whether to use 2D buffer indexing for the fmoe workspace, or 1D +#ifndef MOE_SORTING_FMOE_2D_BUF +#define MOE_SORTING_FMOE_2D_BUF 1 +#endif + // clang-format off // [indexing implementation-1] // using M_a as constexpr block_size to partition all tokens into different slices @@ -171,7 +176,7 @@ struct MoeSortingHostArgs void* p_sorted_token_ids; void* p_sorted_weights; void* p_sorted_expert_ids; - void* p_total_tokens_post_pad; + void* p_total_tokens_post_pad; // [2], [0]: output tokens_post_padded, [1]: actual tokens on current rank (local_tokens or tokens) // we fused the setzero of
output of fused-moe buffer // set this pointer to nullptr will skip this operation void* p_moe_buf; @@ -182,7 +187,18 @@ struct MoeSortingHostArgs index_t unit_size; // this is the M_a of fused-moe kernel index_t num_experts; index_t topk; +#if MOE_SORTING_FMOE_2D_BUF + // NOTE: + // moe_buf_* is a 2d ws buffer used for the following fmoe kernel + // arranged as row*col, where row=tokens(or local_token), col=interm_dim + // we fuse this clearing inside the sorting kernel + // Besides, we require interm_dim to be a multiple of 16 bytes (make sure of this when allocating the ws for fmoe) + index_t moe_buf_interm_dim; // p_moe_buf interm_dim + index_t moe_buf_elem_bytes; // p_moe_buf element byte size (8bit, 16bit, 32bit, etc.) +#else long_index_t moe_buf_bytes; // byte size of p_moe_buf +#endif + }; template @@ -197,6 +213,9 @@ struct MoeSortingKernel using Hargs = MoeSortingHostArgs; + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + struct Kargs { const void* p_topk_ids; @@ -210,8 +229,12 @@ struct MoeSortingKernel void* p_moe_buf; index_t tokens; index_t num_experts; +#if MOE_SORTING_FMOE_2D_BUF + index_t moe_buf_interm_dim; // p_moe_buf interm_dim + index_t moe_buf_elem_bytes; // p_moe_buf element byte size (8bit, 16bit, 32bit, etc.) +#else long_index_t moe_buf_bytes; - +#endif index_t tokens_per_thread; index_t smem_rows; mdiv unit_size_mdiv; @@ -220,10 +243,27 @@ struct MoeSortingKernel // mdiv sub_tokens_mdiv; }; + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { +#if MOE_SORTING_FMOE_2D_BUF + (void)h; + return get_num_cu() * OCCUPANCY; +#else // TODO: assume num-experts not too much return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16)); +#endif } CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h) @@ -263,7 +303,12 @@ struct MoeSortingKernel k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; k.tokens = h.tokens; k.num_experts = h.num_experts; +#if MOE_SORTING_FMOE_2D_BUF + k.moe_buf_interm_dim = h.moe_buf_interm_dim; + k.moe_buf_elem_bytes = h.moe_buf_elem_bytes; +#else k.moe_buf_bytes = h.moe_buf_bytes; +#endif const auto blocks = BlockSize(h); // NOTE: tokens could from p_local_tokens, so here this variable is useless @@ -431,6 +476,24 @@ struct MoeSortingKernel } } + CK_TILE_DEVICE void + moe_buf_set_zero_kernel_2d(void* buf, index_t row, index_t col, index_t elem_bytes) const + { + const long_index_t total_pixels = static_cast(row) * col; + const long_index_t total_bytes = total_pixels * elem_bytes; + const long_index_t total_elems = total_bytes / 16; // always use dwordx4 + + using vector_type = ext_vector_t; + vector_type* p_buf = reinterpret_cast(buf); + auto zero_ = vector_type{0}; + + for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems; + i += (gridDim.x - 1) * BLOCK_SIZE) + { + p_buf[i] = zero_; + } + } + CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id, const WeightType* __restrict__ weights, index_t* p_sorted_token_ids, @@ -863,7 +926,8 @@ struct MoeSortingKernel } if((lid + i_e_ - get_warp_size()) == (num_experts - 1)) { - *p_total_tokens_post_pad = local_cumsum_; + *p_total_tokens_post_pad = local_cumsum_; + p_total_tokens_post_pad[1] = tokens; } }
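// With the extra element stored above, p_total_tokens_post_pad now carries two
// values, matching its // [2] annotation in the host/kernel argument structs.
// A host-side read-back sketch (hypothetical variable names):
//   ck_tile::index_t post_pad[2];
//   hipMemcpy(post_pad, p_total_tokens_post_pad, sizeof(post_pad), hipMemcpyDeviceToHost);
//   // post_pad[0] : token count after per-expert padding to unit_size
//   // post_pad[1] : actual tokens on the current rank (local_tokens or tokens)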
__syncthreads(); @@ -1005,20 +1069,6 @@ CK_TILE_DEVICE void operator()(Kargs kargs) const { - if(blockIdx.x > 0) - { - if(kargs.p_moe_buf) - { - moe_buf_set_zero_kernel(reinterpret_cast(kargs.p_moe_buf), - kargs.moe_buf_bytes); - } - return; - } - const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; - extern __shared__ char smem[]; - -#if MOE_SORTING_USE_EX_KERNEL - (void)numel; index_t tokens_ = [&]() { if constexpr(Problem::LocalToken) { @@ -1029,6 +1079,25 @@ return kargs.tokens; } }(); + + if(blockIdx.x > 0) + { + if(kargs.p_moe_buf) + { +#if MOE_SORTING_FMOE_2D_BUF + moe_buf_set_zero_kernel_2d( + kargs.p_moe_buf, tokens_, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes); +#else + moe_buf_set_zero_kernel(reinterpret_cast(kargs.p_moe_buf), + kargs.moe_buf_bytes); +#endif + } + return; + } + + extern __shared__ char smem[]; + +#if MOE_SORTING_USE_EX_KERNEL return moe_align_block_size_kernel_ex( static_cast(kargs.p_topk_ids), static_cast(kargs.p_weights), @@ -1045,6 +1114,7 @@ kargs.smem_rows, smem); #else + const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; return moe_align_block_size_kernel(static_cast(kargs.p_topk_ids), static_cast(kargs.p_weights), static_cast(kargs.p_sorted_token_ids), @@ -1066,6 +1136,8 @@ namespace impl { // [expert, padded_tokens] CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens) { + // Pad to a multiple of 32. This makes sure that even if the mesh is in 8bit, + // we can still use dwordx4 load/store constexpr index_t chunk = 32; return (tokens + chunk - 1) / chunk * chunk; }; @@ -1261,6 +1333,24 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by } } +template <index_t BLOCK_SIZE> +CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d( + void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks) +{ + const long_index_t total_pixels = static_cast(row) * col; + const long_index_t total_bytes = total_pixels * elem_bytes; + const long_index_t total_elems = total_bytes / 16; // always use dwordx4 + + using vector_type = ext_vector_t; + vector_type* p_buf = reinterpret_cast(buf); + auto zero_ = vector_type{0}; + + for(long_index_t i = gid * BLOCK_SIZE + threadIdx.x; i < total_elems; i += blocks * BLOCK_SIZE) + { + p_buf[i] = zero_; + } +} + } // namespace impl // TODO: tokens could be from @@ -1292,12 +1382,29 @@ CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_expe } // return size in byte -CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_) +// dispatch_policy: 0 - automatically pick the kernel.
1-always use single kernel, 2-always use mp +// kernel +CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, + int num_experts_, + int topk_, + int dispatch_policy_) { #if 1 - if(moe_sorting_is_oneshot(tokens_, num_experts_)) + // return 0; + if(dispatch_policy_ == 0) { - return 0; + if(moe_sorting_is_oneshot(tokens_, num_experts_)) + { + return 0; + } + else + { + return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_); + } + } + else if(dispatch_policy_ == 1) + { + return 0; // always use single kernel } else { @@ -1308,6 +1415,98 @@ CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts #endif } +template +struct MoeSortingClearWorkspaceKernel +{ + using Problem = remove_cvref_t; + static constexpr index_t BLOCK_SIZE = Problem::BlockSize; + static constexpr index_t OCCUPANCY = Problem::Occu; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens + void* p_expert_mesh; // [expert, tokens] + index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens + // used for ws/LDS calculation + index_t num_experts; + index_t mesh_stride; // mesh_stride for p_expert_mesh + index_t mesh_byte_size; + }; + + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_local_tokens = h.p_local_tokens; + k.p_expert_mesh = h.p_ws; + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.mesh_byte_size = impl::moe_sorting_mesh_byte_size(h.tokens, h.num_experts, h.topk); + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + // in byte + CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); + + index_t mesh_stride = [&]() { + if constexpr(Problem::LocalToken) + { + return impl::moe_sorting_mp_mesh_stride(tokens); + } + else + { + return kargs.mesh_stride; + } + }(); + + index_t row_size = mesh_stride; // impl::moe_sorting_mp_mesh_stride(tokens); + index_t pixels = kargs.num_experts * row_size; + index_t total_bytes = pixels * kargs.mesh_byte_size; + index_t total_elems = total_bytes / 16; // always use dwordx4 + + using vector_type = ext_vector_t; + vector_type* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + auto zero_ = vector_type{0}; + + for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems; + i += gridDim.x * BLOCK_SIZE) + { + p_expert_mesh[i] = zero_; + } + } +}; + // below kernel is multi-phase implementation for large token and/or expert case // write into a buffer to record the token cnt @@ -1435,6 +1634,16 @@ struct MoeSortingMultiPhaseKernel_P0 else return tokens; }(); + index_t mesh_stride = [&]() { + if constexpr(Problem::LocalToken) + { + return impl::moe_sorting_mp_mesh_stride(tokens); + } + else + { + return kargs.mesh_stride; + } + }(); 
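// Worked instance of the restride above (moe_sorting_mp_mesh_stride pads the
// row pitch to a multiple of 32): a runtime local token count of 70 yields
//   mesh_stride = (70 + 31) / 32 * 32 = 96
// while 128 stays 128, so every expert row remains 16-byte addressable even
// when the mesh is stored as 8-bit entries.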
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile @@ -1449,12 +1658,11 @@ struct MoeSortingMultiPhaseKernel_P0 if constexpr(Problem::LocalToken) { if(static_cast(curr_token_id) < tokens) - p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = + p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff; } else - p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = - (curr_topk_id + 1) & 0xffff; + p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff; }); } } @@ -1479,6 +1687,7 @@ struct MoeSortingMultiPhaseKernel_P1 struct Kargs { const void* p_local_expert_mask; // [expert] + const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens void* p_expert_mesh; // [expert, tokens] void* p_expert_cumsum; index_t mesh_stride; // mesh_stride for p_expert_mesh @@ -1488,6 +1697,7 @@ struct MoeSortingMultiPhaseKernel_P1 { Kargs k; k.p_local_expert_mask = h.p_local_expert_mask; + k.p_local_tokens = h.p_local_tokens; k.p_expert_mesh = h.p_ws; k.p_expert_cumsum = reinterpret_cast( reinterpret_cast(h.p_ws) + @@ -1511,12 +1721,9 @@ struct MoeSortingMultiPhaseKernel_P1 { __shared__ char smem[GetSmemSize()]; - int eid = blockIdx.x; - + int eid = blockIdx.x; constexpr index_t index_pack = Problem::SubTokenTile; // always packed using r_t = ext_vector_t; // always use int32x4 - r_t* p_expert_mesh = reinterpret_cast( - reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); @@ -1524,7 +1731,32 @@ struct MoeSortingMultiPhaseKernel_P1 auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; - int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return 0; // will not use if not LocalToken + } + }(); + + index_t mesh_stride = [&]() { + if constexpr(Problem::LocalToken) + { + return impl::moe_sorting_mp_mesh_stride(tokens); + } + else + { + return kargs.mesh_stride; + } + }(); + + r_t* p_expert_mesh = reinterpret_cast( + reinterpret_cast(kargs.p_expert_mesh) + eid * mesh_stride); + + int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; if constexpr(Problem::LocalExpertMasking) { @@ -1538,7 +1770,7 @@ struct MoeSortingMultiPhaseKernel_P1 { int position = i * BLOCK_SIZE + threadIdx.x; r_t v{0}; - if(position < (kargs.mesh_stride / index_pack)) + if(position < (mesh_stride / index_pack)) v = p_expert_mesh[position]; index_t local_sum = 0; static_for<0, index_pack, 1>{}( @@ -1835,7 +2067,7 @@ struct MoeSortingMultiPhaseKernel_P2 const void* p_local_tokens; // [1] void* p_expert_mesh; // [expert, tokens] void* p_expert_cumsum; // [expert + 1] - void* p_total_tokens_post_pad; // [1] + void* p_total_tokens_post_pad; // [2] void* p_sorted_expert_ids; void* p_moe_buf; index_t tokens; @@ -1863,15 +2095,36 @@ struct MoeSortingMultiPhaseKernel_P2 k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; +#if MOE_SORTING_FMOE_2D_BUF + k.moe_buf_interm_dim = h.moe_buf_interm_dim; + k.moe_buf_elem_bytes = h.moe_buf_elem_bytes; +#else k.moe_buf_bytes = h.moe_buf_bytes; +#endif return k; } + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + 
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { +#if MOE_SORTING_FMOE_2D_BUF + return dim3(h.num_experts + get_num_cu() * OCCUPANCY); +#else // use 1 block to cumsum return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); +#endif } CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } @@ -1888,11 +2141,21 @@ { if(blockIdx.x > 0) { +#if MOE_SORTING_FMOE_2D_BUF + impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, + kargs.tokens, + kargs.moe_buf_interm_dim, + kargs.moe_buf_elem_bytes, + blockIdx.x - 1, + gridDim.x - 1); + return; +#else impl::moe_buf_set_zero_kernel( reinterpret_cast(kargs.p_moe_buf), kargs.moe_buf_bytes, blockIdx.x - 1); return; +#endif } __shared__ char smem[GetSmemSize()]; IndexType* s = reinterpret_cast(smem); @@ -2223,7 +2486,7 @@ struct MoeSortingMultiPhaseKernel_P23 const void* p_local_tokens; // [1] void* p_expert_mesh; // [expert, tokens] void* p_expert_cumsum; // [expert + 1] - void* p_total_tokens_post_pad; // [1] + void* p_total_tokens_post_pad; // [2] void* p_sorted_expert_ids; void* p_sorted_token_ids; @@ -2235,7 +2498,17 @@ struct MoeSortingMultiPhaseKernel_P23 index_t mesh_stride; // mesh_stride for p_expert_mesh mdiv unit_size_mdiv; mdiv topk_mdiv; - long_index_t moe_buf_bytes; +#if MOE_SORTING_FMOE_2D_BUF + // NOTE: + // moe_buf_* is a 2d ws buffer used for the following fmoe kernel + // arranged as row*col, where row=tokens(or local_token), col=interm_dim + // we fuse this clearing inside the sorting kernel + // Besides, we require interm_dim to be a multiple of 16 bytes (make sure of this when allocating the ws for fmoe) + index_t moe_buf_interm_dim; // p_moe_buf interm_dim + index_t moe_buf_elem_bytes; // p_moe_buf element byte size (8bit, 16bit, 32bit, etc.)
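// Illustrative sizing under these fields: tokens = 1024, moe_buf_interm_dim =
// 4096, moe_buf_elem_bytes = 2 (fp16) gives 1024 * 4096 * 2 bytes = 8 MiB for
// the fused clear, and interm_dim * elem_bytes = 8192 is a multiple of 16, as
// the NOTE above requires for dwordx4 stores.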
+#else + long_index_t moe_buf_bytes; // byte size of p_moe_buf +#endif }; CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) @@ -2262,16 +2535,37 @@ struct MoeSortingMultiPhaseKernel_P23 k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; k.topk_mdiv = mdiv{static_cast(h.topk)}; +#if MOE_SORTING_FMOE_2D_BUF + k.moe_buf_interm_dim = h.moe_buf_interm_dim; + k.moe_buf_elem_bytes = h.moe_buf_elem_bytes; +#else k.moe_buf_bytes = h.moe_buf_bytes; +#endif return k; } + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { +#if MOE_SORTING_FMOE_2D_BUF + return dim3(h.num_experts + get_num_cu() * OCCUPANCY); +#else // use 1 block to cumsum // return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); +#endif } CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } @@ -2287,13 +2581,34 @@ struct MoeSortingMultiPhaseKernel_P23 // reduce single pixel within a wave CK_TILE_DEVICE void operator()(Kargs kargs) const { + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); + if(static_cast(blockIdx.x) >= kargs.num_experts) { +#if MOE_SORTING_FMOE_2D_BUF + impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, + tokens, + kargs.moe_buf_interm_dim, + kargs.moe_buf_elem_bytes, + blockIdx.x - kargs.num_experts, + gridDim.x - kargs.num_experts); + return; +#else impl::moe_buf_set_zero_kernel( reinterpret_cast(kargs.p_moe_buf), kargs.moe_buf_bytes, blockIdx.x - kargs.num_experts); return; +#endif } extern __shared__ char smem[]; @@ -2428,13 +2743,15 @@ struct MoeSortingMultiPhaseKernel_P23 { auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor; if(blockIdx.x == 0) + { p_total_tokens_post_pad[0] = total_tokens_post_pad; + p_total_tokens_post_pad[1] = tokens; + } p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad; } } __syncthreads(); - { const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); @@ -2463,14 +2780,14 @@ struct MoeSortingMultiPhaseKernel_P23 return; // skip empty expert } - index_t tokens = [&]() { + index_t mesh_stride = [&]() { if constexpr(Problem::LocalToken) { - return reinterpret_cast(kargs.p_local_tokens)[0]; + return impl::moe_sorting_mp_mesh_stride(tokens); } else { - return kargs.tokens; + return kargs.mesh_stride; } }(); @@ -2478,7 +2795,8 @@ struct MoeSortingMultiPhaseKernel_P23 constexpr index_t index_pack = Problem::SubTokenTile; // always packed using r_t = ext_vector_t; // always use int32x4 using d_t = ext_vector_t; - int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int prev_cumsum = 0; for(int i = 0; i < loops; i++) @@ -2487,8 +2805,7 @@ struct MoeSortingMultiPhaseKernel_P23 r_t x_v = 0; if(i_token_pack < (tokens + index_pack - 1) / index_pack) { - x_v = reinterpret_cast(p_expert_mesh + - eid * kargs.mesh_stride)[i_token_pack]; + x_v = reinterpret_cast(p_expert_mesh + eid * mesh_stride)[i_token_pack]; } r_t x_r; diff --git 
a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp index 181266d7af..ea218b9c25 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -73,4 +73,12 @@ struct MoeSortingProblemMp SubTokenTile == 8 || SubTokenTile == 16); }; +template +struct MoeSortingClearWorkspaceProblem +{ + static constexpr bool LocalToken = LocalToken_; + static constexpr index_t BlockSize = BlockSize_; + static constexpr index_t Occu = Occu_; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp index e9577e2304..17c38a2632 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp @@ -267,8 +267,7 @@ struct FusedMoeGemmPipeline_FlatmmEx statically_indexed_array as; auto gld_a = [&]>( - auto& a_store_, auto i_access, PreNop = {}) - { + auto& a_store_, auto i_access, PreNop = {}) { async_load_tile_raw(a_store_, a_win, i_access, PreNop{}); }; auto move_a = [&]() { @@ -278,43 +277,40 @@ struct FusedMoeGemmPipeline_FlatmmEx load_tile_raw(a_, win_, i_access); }; - auto gld_g = [&]>( - auto& g_, auto i_access, PreNop = {}) - { - if constexpr(IsGateOnly) - { - // TODO: hack! - if constexpr(i_access.value == 0) + auto gld_g = + [&]>(auto& g_, auto i_access, PreNop = {}) { + if constexpr(IsGateOnly) { - g_win.bottom_tensor_view_ = g_view; + // TODO: hack! + if constexpr(i_access.value == 0) + { + g_win.bottom_tensor_view_ = g_view; + } + else if constexpr(i_access.value == issues_g / 2) + { + g_win.bottom_tensor_view_ = u_view; + } } - else if constexpr(i_access.value == issues_g / 2) - { - g_win.bottom_tensor_view_ = u_view; - } - } - load_tile_raw(g_, g_win, i_access, FALSE, PreNop{}); - }; + load_tile_raw(g_, g_win, i_access, FALSE, PreNop{}); + }; auto move_g = [&]() { move_tile_window(g_win, {number<0>{}, number{}, number<0>{}}); }; statically_indexed_array ds; - auto gld_d = [&]>( - auto& d_, auto i_access, PreNop = {}) - { - load_tile_raw(d_, d_win, i_access, FALSE, PreNop{}); - }; + auto gld_d = + [&]>(auto& d_, auto i_access, PreNop = {}) { + load_tile_raw(d_, d_win, i_access, FALSE, PreNop{}); + }; auto move_d = [&]() { // d move along gemm-n move_tile_window(d_win, {number{}, number<0>{}}); }; - auto atomic_add_o = [&]>( - auto& o_, auto i_access, PreNop = {}) - { - update_tile_raw(o_win, o_, i_access, TRUE, PreNop{}); - }; + auto atomic_add_o = + [&]>(auto& o_, auto i_access, PreNop = {}) { + update_tile_raw(o_win, o_, i_access, TRUE, PreNop{}); + }; auto acc_0 = Policy::template MakeCBlockTile_Gemm0(); auto acc_1s = generate_tuple( diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index f1e8bcc0a8..c201293389 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -13,9 +13,9 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" +#include 
"ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" @@ -28,8 +28,10 @@ #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" @@ -44,10 +46,10 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index 8dd1d1ec28..cfbd78967f 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" namespace ck_tile { @@ -15,6 +16,19 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { +#if defined(__gfx950__) + constexpr bool is_a_load_tr = std::is_same_v, + tensor_layout::gemm::ColumnMajor>; + constexpr bool is_b_load_tr = std::is_same_v, + tensor_layout::gemm::RowMajor>; +#else + constexpr bool is_a_load_tr = false; + constexpr bool is_b_load_tr = false; +#endif + constexpr auto wg_attr_num_access = (is_a_load_tr || is_b_load_tr) + ? 
WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single; + if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) @@ -33,21 +47,41 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16<>{}, 2, 2); } else { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16<>{}, 2, 2); } #else - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + using WG = WarpGemmMfmaDispatcher; + return make_tuple(WG{}, 4, 1); #endif } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); + using WG = WarpGemmMfmaDispatcher; + return make_tuple(WG{}, 4, 1); } else { diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index d4e23d12dd..e1b0792ecf 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -218,10 +218,16 @@ struct BlockUniversalGemmAsBsCr BLdsTile b_warp_tile_; // C += A * B - template + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as correspoinding " @@ -300,14 +306,23 @@ struct BlockUniversalGemmAsBsCr ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { if constexpr(std::is_same_v) { load_interleaved_pk_type(a_warp_tile_, a_block_window); } + else if constexpr(ALoadTranspose) + { + a_warp_tile_ = load_tile_transpose(a_block_window); + } else { load_tile(a_warp_tile_, a_block_window); @@ -316,6 +331,10 @@ struct BlockUniversalGemmAsBsCr { load_interleaved_pk_type(b_warp_tile_, b_block_window); } + else if constexpr(BLoadTranspose) + { + b_warp_tile_ = load_tile_transpose(b_block_window); + } else { load_tile(b_warp_tile_, b_block_window); @@ -323,10 +342,16 @@ struct BlockUniversalGemmAsBsCr } // C += A * B - template + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, [[maybe_unused]] ASmemBlockWindow& a_block_window, - [[maybe_unused]] BSmemBlockWindow& b_block_window) + [[maybe_unused]] BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as correspoinding " @@ -382,40 +407,73 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; static constexpr auto ALdsTileDistr = - decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + make_static_tile_distribution(MakeABlockDistributionEncode()); static constexpr auto BLdsTileDistr = - decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + make_static_tile_distribution(MakeBBlockDistributionEncode()); using ALdsTile = 
decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); ALdsTile a_warp_tile_; - ALdsTile b_warp_tile_; + BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(MakeBBlockDistributionEncode()); + constexpr auto a_lds_load_distr = [&]() { + if constexpr(ALoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeABlockDistributionEncode()), + ADataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeABlockDistributionEncode()); + }(); + constexpr auto b_lds_load_distr = [&]() { + if constexpr(BLoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeBBlockDistributionEncode()), + BDataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeBBlockDistributionEncode()); + }(); + constexpr auto a_lds_shape = []() { + if constexpr(ALoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto b_lds_shape = []() { + if constexpr(BLoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto k_idx_offset = KIdx * KPerInnerLoop; + constexpr auto a_offset = + ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; + constexpr auto b_offset = + BLoadTranspose ? 
multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; auto a_lds_gemm_window = make_tile_window( - a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {0, KIdx * KPerInnerLoop}, - a_lds_load_tile_distr); + a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr); auto b_lds_gemm_window = make_tile_window( - b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {0, KIdx * KPerInnerLoop}, - b_lds_load_tile_distr); + b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); if constexpr(std::is_same_v) { load_interleaved_pk_type(a_warp_tile_, a_block_window); } + else if constexpr(ALoadTranspose) + { + a_warp_tile_ = load_tile_transpose(a_lds_gemm_window); + } else { load_tile(a_warp_tile_, a_lds_gemm_window); @@ -424,6 +482,10 @@ struct BlockUniversalGemmAsBsCr { load_interleaved_pk_type(b_warp_tile_, b_block_window); } + else if constexpr(BLoadTranspose) + { + b_warp_tile_ = load_tile_transpose(b_lds_gemm_window); + } else { load_tile(b_warp_tile_, b_lds_gemm_window); @@ -431,10 +493,16 @@ struct BlockUniversalGemmAsBsCr } // C += A * B - template + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as corresponding " @@ -442,7 +510,7 @@ struct BlockUniversalGemmAsBsCr // hot loop: static_for<0, KRepeat, 1>{}([&](auto kIter) { - LocalPrefetch(a_block_window, b_block_window); + LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); __builtin_amdgcn_sched_barrier(0); // NOTE: Synchronize threads in a workgroup at the start of each MAC // cluster, except the first, as we can shorten the non-MAC cluster a bit @@ -543,29 +611,45 @@ struct BlockUniversalGemmAsBsCr return c_block_tensor; } - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); } // C += A * B - template + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_(c_block_tensor, a_block_window, b_block_window); + block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr); } // C = A * B - template + template CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { auto c_block_tensor = MakeCBlockTile(); - block_gemm_impl_(c_block_tensor, a_block_window, b_block_window); + block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr); return c_block_tensor; } diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 09c7d58558..9c1ce73eac 100644 ---
a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -9,35 +9,41 @@ namespace ck_tile { -struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs +/// @brief The Batched GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref BatchedGemmKernel "BatchedGemmKernel" when creating the kernel +/// arguments object. It contains all necessary information required to build proper kernel +/// arguments and launch the kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. +struct BatchedGemmHostArgs : public ck_tile::UniversalGemmHostArgs<> { - CK_TILE_HOST BatchedGemmHostArgs() = default; - CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_, - const void* b_ptr_, - void* c_ptr_, - ck_tile::index_t k_batch_, - ck_tile::index_t M_, - ck_tile::index_t N_, - ck_tile::index_t K_, - ck_tile::index_t stride_A_, - ck_tile::index_t stride_B_, - ck_tile::index_t stride_C_, - ck_tile::index_t batch_stride_A_, - ck_tile::index_t batch_stride_B_, - ck_tile::index_t batch_stride_C_, - ck_tile::index_t batch_count_) - : GemmHostArgs(a_ptr_, - b_ptr_, - {}, - c_ptr_, - k_batch_, - M_, - N_, - K_, - stride_A_, - stride_B_, - {}, - stride_C_), + CK_TILE_HOST explicit BatchedGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + ck_tile::index_t k_batch_, + ck_tile::index_t M_, + ck_tile::index_t N_, + ck_tile::index_t K_, + ck_tile::index_t stride_A_, + ck_tile::index_t stride_B_, + ck_tile::index_t stride_C_, + ck_tile::index_t batch_stride_A_, + ck_tile::index_t batch_stride_B_, + ck_tile::index_t batch_stride_C_, + ck_tile::index_t batch_count_) + : UniversalGemmHostArgs<>({a_ptr_}, + {b_ptr_}, + {/*ds_ptr*/}, + c_ptr_, + k_batch_, + M_, + N_, + K_, + {stride_A_}, + {stride_B_}, + {/*stride_Ds_*/}, + stride_C_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), batch_stride_E(batch_stride_C_), @@ -52,36 +58,43 @@ struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs }; template -struct BatchedGemmKernel : public GemmKernel +struct BatchedGemmKernel { - using Base = GemmKernel; + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; - using GemmKernelArgs = typename ck_tile::GemmKernelArgs<>; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; - using ADataType = typename Base::ADataType; - using BDataType = typename Base::BDataType; - using CDataType = typename Base::EDataType; + /// @brief Specify the layout configurations for A, B, E and D + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; - using TilePartitioner = typename Base::TilePartitioner; - using GemmPipeline = typename Base::GemmPipeline; - using EpiloguePipeline = typename Base::EpiloguePipeline; - using ALayout = typename Base::ALayout; - using BLayout = typename Base::BLayout; - using CLayout = typename Base::ELayout; + /// @brief Specify the data type configurations for A, B, E and D + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - using P_ = GemmPipeline; + /// @brief ALayout and ADataType are expected to be scalars, not a tuple.
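The scalar-only restrictions enforced by the static_asserts that follow rely on a compile-time detection idiom: the kernel probes whether a layout or data type behaves like a tuple and rejects it if so. A minimal, self-contained sketch of such a detector (illustrative only; is_tuple_like below is not the actual ck_tile utility, which builds on its own is_detected and tuple types):

    #include <tuple>
    #include <type_traits>

    // Probe std::tuple_size<T>: it is only specialized for tuple-like types,
    // so the partial specialization below is selected exactly for those.
    template <typename T, typename = void>
    struct is_tuple_like : std::false_type
    {
    };

    template <typename T>
    struct is_tuple_like<T, std::void_t<typename std::tuple_size<T>::type>> : std::true_type
    {
    };

    struct RowMajor
    {
    }; // stand-in scalar layout tag

    static_assert(!is_tuple_like<RowMajor>::value, "scalar layouts pass");
    static_assert(is_tuple_like<std::tuple<RowMajor>>::value, "tuples are detected");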
+ static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); - return concat('_', "gemm_batched", gemm_prec_str, - concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), - concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), - concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); - // clang-format on - } + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - struct BatchedGemmKernelArgs : GemmKernelArgs + /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "C/ELayout and C/EDataType must be scalars."); + + struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<> { index_t batch_stride_A; index_t batch_stride_B; @@ -91,27 +104,41 @@ struct BatchedGemmKernel : public GemmKernel const std::string + { + // clang-format off + using P_ = GemmPipeline; + return concat('_', "gemm_batched", gemm_prec_str(), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); + // clang-format on + } + + CK_TILE_HOST static constexpr auto + GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3 { return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch); } - __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return dim3(UniversalGemmKernel::KernelBlockSize); + } CK_TILE_HOST static constexpr BatchedGemmKernelArgs MakeKernelArgs(const BatchedGemmHostArgs& hostArgs) { - return BatchedGemmKernelArgs{{hostArgs.a_ptr, - hostArgs.b_ptr, - {}, + return BatchedGemmKernelArgs{{hostArgs.as_ptr, + hostArgs.bs_ptr, + hostArgs.ds_ptr, hostArgs.e_ptr, hostArgs.M, hostArgs.N, hostArgs.K, - hostArgs.stride_A, - hostArgs.stride_B, - {}, + hostArgs.stride_As, + hostArgs.stride_Bs, + hostArgs.stride_Ds, hostArgs.stride_E, hostArgs.k_batch}, hostArgs.batch_stride_A, @@ -125,6 +152,12 @@ struct BatchedGemmKernel : public GemmKernel bool + { + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const { const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); @@ -134,18 +167,18 @@ struct BatchedGemmKernel : public GemmKernel(kargs.a_ptr) + batch_offset_A + - splitk_batch_offset.a_k_split_offset; + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + batch_offset_A + + splitk_batch_offset.as_k_split_offset[0]; const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B); const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B); - const BDataType* b_ptr = static_cast(kargs.b_ptr) + batch_offset_B + - splitk_batch_offset.b_k_split_offset; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + batch_offset_B + + splitk_batch_offset.bs_k_split_offset[0]; const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E); const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E); @@ -154,7 +187,8 @@ struct BatchedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, 
splitk_batch_offset, i_m, i_n); + UniversalGemmKernel::RunGemm( + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 516d4298ef..079d3972d1 100755 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -12,6 +12,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/stream_utils.hpp" #include "ck_tile/core/utility/env.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -24,14 +25,11 @@ namespace ck_tile { /// and launch kernel on GPU. /// This structure defines the GEMM problem configuration by stating all required information /// like M,N,K sizes and respective strides. -/// NumDTensor describes the number of D tensors. -template struct GemmHostArgs { CK_TILE_HOST GemmHostArgs() = default; CK_TILE_HOST GemmHostArgs(const void* a_ptr_, const void* b_ptr_, - const std::array& ds_ptr_, void* e_ptr_, index_t k_batch_, index_t M_, @@ -39,18 +37,15 @@ struct GemmHostArgs index_t K_, index_t stride_A_, index_t stride_B_, - const std::array& stride_Ds_, index_t stride_E_) : a_ptr(a_ptr_), b_ptr(b_ptr_), - ds_ptr(ds_ptr_), e_ptr(e_ptr_), M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), - stride_Ds(stride_Ds_), stride_E(stride_E_), k_batch(k_batch_) { @@ -58,18 +53,18 @@ struct GemmHostArgs const void* a_ptr; const void* b_ptr; - const std::array ds_ptr; union { void* e_ptr; void* c_ptr; }; + index_t M; index_t N; index_t K; index_t stride_A; index_t stride_B; - const std::array stride_Ds; + union { index_t stride_E; @@ -79,990 +74,96 @@ struct GemmHostArgs index_t k_batch; }; -/// @brief The GEMM kernel device arguments. -template -struct GemmKernelArgs -{ - /// @brief The A input tensor's pointer to device memory. - const void* a_ptr; - /// @brief The B input tensor's pointer to device memory. - const void* b_ptr; - /// @brief The Ds input tensor's pointer to device memory. - const std::array ds_ptr; - /// @brief The E output tensor's pointer to device memory. - void* e_ptr; - /// @brief GEMM's M dimension size. - index_t M; - /// @brief GEMM's N dimension size. - index_t N; - /// @brief GEMM's K dimension size. - index_t K; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of A tensor. - index_t stride_A; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of B tensor. - index_t stride_B; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of Ds tensor. - std::array stride_Ds; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of E tensor. - index_t stride_E; - index_t k_batch; -}; - -/// @brief The GEMM kernel template. -/// -/// @paragraph Overview Overview -/// This class provides the generic matrix multiplication kernel template. By semantic -/// division of GEMM algorithm into following parts we achieve flexible, versatile -/// and robust kernel implementation. -/// -/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() -/// function call operator" which determines the work scope of each workgroup. -/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm. 
-/// This is the place where each workgroup is loading data from global memory and -/// carrying out dot products. -/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation -/// responsible for storing results to global memory. This is also the place where -/// any additional operator fusion may take place. -/// -/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_ -/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all -/// internal details of those functional parts. You can think of it like both gemm and -/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover -/// the policy is responsible for definition of all necessary data layouts and thread's -/// work distribution. -/// -/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the -/// output data tile to be calculated. It determines the workgroup to -/// data relationship (or in other words - which data would be -/// processed and calculated by which workgroup). -/// @tparam GemmPipeline_ The type of class which provides the core part of matrix -/// multiplication. This class should provide implementation of data -/// loading from global memory and performing block-wise matrix -/// multiplication. You can think of it as a work done by single -/// workgroup point of view. -/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix -/// multiplication implementation. It is responsible for storing -/// results calculated by @ref GemmPipeline_ "GemmPipeline" to -/// the output E tensor in global memory. template struct GemmKernel { + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; + using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - // TODO: GemmPipeline::CLayout -> GemmPipeline::ELayout will be changed for multi-ABD - using ELayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - // Get the persistent kernel if the pipeline has it available - struct has_persistent_kernel - { - template - using has_persistent_type = decltype(T::UsePersistentKernel); - - static constexpr bool value = []() { - if constexpr(is_detected{}) - return GemmPipeline::UsePersistentKernel; - else - return false; - }(); - }; - static constexpr bool PersistentKernel = has_persistent_kernel::value; + /// @brief Specify the layout configurations for A, B, E and D + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + /// @brief Specify the data type configurations for A, B, E and D using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - // Below type is actually accumulation data type - the output of block GEMM. using EDataType = remove_cvref_t; - static constexpr index_t NumDTensor = DsDataType::size(); + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. 
Multiple parameters are not currently supported."); - static constexpr auto I0 = number<0>(); - static constexpr auto I1 = number<1>(); - static constexpr auto I2 = number<2>(); - static constexpr auto I3 = number<3>{}; + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - static_assert(DsLayout::size() == DsDataType::size(), - "The size of DsLayout and DsDataType should be the same"); - using KernelArgs = GemmKernelArgs; + /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "C/ELayout and C/EDataType must be scalars."); - [[nodiscard]] CK_TILE_HOST static const std::string GetName() + static constexpr index_t NumATensor = 1; + static constexpr index_t NumBTensor = 1; + + CK_TILE_HOST static auto GetName() -> const std::string { - // clang-format off - return concat('_', "gemm", gemm_prec_str, GemmPipeline::GetName()); - // clang-format on + return UniversalGemmKernel::GetName(); } - CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3 { - return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + return UniversalGemmKernel::GridSize(M, N, KBatch); } - /** - * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. - * @return The maximum occupancy grid size. - * @note This function queries the maximum occupancy of the kernel using - * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. - */ CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { - using Kernel = GemmKernel; - const auto kernel = kentry; - int occupancy; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); - const int grid_size = get_available_compute_units(s) * occupancy; - return dim3(grid_size, 1, 1); + return UniversalGemmKernel::MaxOccupancyGridSize(s); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - - CK_TILE_HOST static constexpr KernelArgs - MakeKernelArgs(const GemmHostArgs& hostArgs) + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { - - return KernelArgs{hostArgs.a_ptr, - hostArgs.b_ptr, - hostArgs.ds_ptr, - hostArgs.e_ptr, - hostArgs.M, - hostArgs.N, - hostArgs.K, - hostArgs.stride_A, - hostArgs.stride_B, - hostArgs.stride_Ds, - hostArgs.stride_E, - hostArgs.k_batch}; + return UniversalGemmKernel::BlockSize(); } - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_HOST static constexpr auto MakeKernelArgs(const GemmHostArgs& hostArgs) -> + typename UniversalGemmKernel::KernelArgs { - return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + /// @brief Universal GEMM requires array objects and corresponding stride information for + /// matrices A, B. 
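A few hunks below, the SplitKBatchOffset helper is removed from this file (the universal kernel keeps an equivalent). Its slicing arithmetic is worth spelling out: K is divided into k_batch chunks, each rounded up to a multiple of the warp-tile K extent, and the last split absorbs the remainder. A sketch with illustrative numbers:

    #include <cstdint>

    // Split-K slicing as in SplitKBatchOffset. K1 is the warp-tile K extent.
    // With K = 1000, k_batch = 4, K1 = 16:
    //   K_t   = k_batch * K1 = 64
    //   KRead = ceil(K / K_t) * K1 = 16 * 16 = 256
    // Splits 0..2 read 256 elements each; the last reads 1000 - 3 * 256 = 232.
    void split_k_example()
    {
        const std::int32_t K = 1000, k_batch = 4, K1 = 16;
        const std::int32_t K_t   = k_batch * K1;
        const std::int32_t KRead = (K + K_t - 1) / K_t * K1;
        for(std::int32_t k_id = 0; k_id < k_batch; ++k_id)
        {
            // Per-split K; the k-split offset is k_id * KRead
            // (scaled by a stride when K is the non-contiguous dimension).
            const std::int32_t splitted_k =
                (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
            (void)splitted_k;
        }
    }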
+ return UniversalGemmKernel::MakeKernelArgs( + UniversalGemmHostArgs( + {hostArgs.a_ptr}, + {hostArgs.b_ptr}, + {/*hostArgs.ds_ptr*/}, + hostArgs.e_ptr, + hostArgs.k_batch, + hostArgs.M, + hostArgs.N, + hostArgs.K, + {hostArgs.stride_A}, + {hostArgs.stride_B}, + {/*hostArgs.stride_Ds*/}, + hostArgs.stride_E)); } - struct SplitKBatchOffset + CK_TILE_HOST static auto + IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool { - __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) - { - constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); - const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); - const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); - - if constexpr(std::is_same_v) - { - a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); - } - else if constexpr(std::is_same_v) - { - a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A); - } - - if constexpr(std::is_same_v) - { - b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B); - } - else if constexpr(std::is_same_v) - { - b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); - } - - if(k_id < static_cast(kargs.k_batch - 1)) - { - splitted_k = __builtin_amdgcn_readfirstlane(KRead); - } - else - { - splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); - } - } - - index_t a_k_split_offset; - index_t b_k_split_offset; - index_t splitted_k; - }; - - CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) - { - if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value) - { - if(kargs.k_batch != 1) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); - } - return false; - } - } - - if constexpr(std::is_same_v) - { - if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && - GemmPipeline::kPadK == false) // k_batch is extra compared to flatmm - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " - "without padding!"); - } - return false; - } - if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); - } - return false; - } - } - else - { - if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support M that is not a multiple of MPerBlock without padding!"); - } - return false; - } - if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); - } - return false; - } - } - - if constexpr(std::is_same_v) - { - if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support N that is not a multiple of NPerBlock without padding!"); - } - return false; - } - if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); - } - return false; - } - } - else - { - if(kargs.K % 
(TilePartitioner::KPerBlock * kargs.k_batch) != 0 && - GemmPipeline::kPadK == false) // again k_batch is extra compared to flatmm - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " - "without padding!"); - } - return false; - } - if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); - } - return false; - } - } - - bool DTesnorIsValid = {true}; - static_for<0, NumDTensor, 1>{}([&](auto index) { - using DiLayout = remove_cvref_t>; - if(std::is_same_v == false) - { - DTesnorIsValid = false; - } - if constexpr(std::is_same_v) - { - if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " - "NPerBlock without padding!"); - } - DTesnorIsValid = false; - } - if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); - } - DTesnorIsValid = false; - } - } - else - { - if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " - "MPerBlock without padding!"); - } - DTesnorIsValid = false; - } - if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); - } - DTesnorIsValid = false; - } - } - }); - - if constexpr(std::is_same_v) - { - if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support N that is not a multiple of NPerBlock without padding!"); - } - return false; - } - if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); - } - return false; - } - } - else - { - if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support M that is not a multiple of MPerBlock without padding!"); - } - return false; - } - if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); - } - return false; - } - } - return DTesnorIsValid; + return UniversalGemmKernel::IsSupportedArgument(kargs); } - template - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void { - static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); - - const auto& a_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_A, 
1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - }(); - - const auto& b_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - if constexpr(TilePartitioner::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - return make_naive_tensor_view( - b_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.N), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - else - { - if constexpr(TilePartitioner::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - if constexpr(GemmPipeline::Preshuffle) - { - index_t kFlatK = - GemmPipeline::BlockGemmShape::flatKPerWarp * - (splitk_batch_offset.splitted_k / - TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - - return make_naive_tensor_view( - b_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - } - }(); - - const auto& ds_tensor_view = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - using DDataType_ = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.N, kargs.M), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - }, - number{}); - - // TODO: enable vector write for C in ColMajor - const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm. 
- make_tuple(1, kargs.stride_E), - number<1>{}, - number<1>{}); - } - }(); - - return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, e_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& b_flat_pad_view = views.at(I1); - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I1); - if constexpr(std::is_same_v) - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor - const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); - if constexpr(std::is_same_v) - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - if constexpr(GemmPipeline::Preshuffle) - { - // For flatmm, we need to use the flat B tensor view - return make_tuple(a_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view); - } - else - { - return make_tuple(a_pad_view, b_pad_view, ds_pad_view, e_pad_view); - } - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& b_block_window = [&]() { - if constexpr(GemmPipeline::Preshuffle) - { - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); - } - else - { - if constexpr(std::is_same_v) - { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - } - else - { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {0, i_n}); - } - } - }(); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( - e_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); - - return make_tuple(a_block_window, 
b_block_window, ds_block_window, e_block_window); - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param ds_ptr input Ds pointer - * @param e_ptr output E pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - template - CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, - const BDataType* b_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* smem_ptr_0, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0); - - if(UseDefaultScheduler || (get_warp_id() == 0)) - { - auto& c_block_window = gemm_tile_windows.at(I3); - - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param ds_ptr input Ds pointer - * @param e_ptr output E pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
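The ping-pong mechanism mentioned in the note above overlaps global-to-LDS loads with math by alternating two shared-memory buffers. A minimal sketch of that control flow, with a hypothetical Pipeline interface (prefetch and compute are illustrative names, not ck_tile APIs):

    // Ping-pong over two LDS buffers: prefetch the next K-slice into one
    // buffer while the MFMA pipeline consumes the other. Sketch only.
    template <typename Pipeline>
    void run_ping_pong(Pipeline& p, int num_loop, void* smem_0, void* smem_1)
    {
        void* bufs[2] = {smem_0, smem_1};
        p.prefetch(bufs[0]); // fill the first buffer up front
        for(int i = 0; i < num_loop; ++i)
        {
            if(i + 1 < num_loop)
                p.prefetch(bufs[(i + 1) % 2]); // next slice, other buffer
            p.compute(bufs[i % 2]);            // current slice, this buffer
        }
    }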
- * - */ - CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, - const BDataType* b_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - - // Non-persistent kernel entry point - template > - CK_TILE_DEVICE void operator()(KernelArgs kargs) const - { - const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); - const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - - const SplitKBatchOffset splitk_batch_offset(kargs); - - // options - const ADataType* a_ptr = - static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; - const BDataType* b_ptr = - static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - - EDataType* e_ptr = static_cast(kargs.e_ptr); - - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - } - - // Persistent kernel entry point - template , typename = void> - CK_TILE_DEVICE void operator()(KernelArgs kargs) const - { - const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); - const auto num_tiles = - __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N)); - const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch); - auto block_id = __builtin_amdgcn_readfirstlane(get_block_id()); - - while(block_id < num_work) - { - // Get the tile index for this block - const auto tile_idx = 
__builtin_amdgcn_readfirstlane(block_id % num_tiles); - const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - - // Get the SplitK offset for this block - const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles); - const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); - const ADataType* a_ptr = - static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; - const BDataType* b_ptr = - static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - EDataType* e_ptr = static_cast(kargs.e_ptr); - - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - // Run the GEMM - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - // Advance to the next work item - block_id += grid_size; - if(block_id >= num_work) - { - break; - } - } + UniversalGemmKernel{}.template operator()(kargs); } }; - } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp new file mode 100644 index 0000000000..34340008d4 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/// @brief The MultiD GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GemmKernelMultiD "GemmKernelMultiD" when creating the kernel +/// arguments object. It contains all necessary information required to build proper kernel +/// arguments and launch the kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. NumDTensor +/// describes the number of D tensors.
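The persistent entry point deleted above now lives in UniversalGemmKernel (which retains the PersistentKernel detection seen later in this patch); its control flow is a plain grid-stride loop over the flattened (tile, k-split) work space, as in this host-style sketch:

    // Grid-stride distribution used by the persistent kernel: a fixed grid of
    // `grid_size` workgroups walks num_tiles * k_batch work items.
    void persistent_loop(int grid_size, int num_tiles, int k_batch, int block_id)
    {
        const int num_work = num_tiles * k_batch;
        while(block_id < num_work)
        {
            const int tile_idx = block_id % num_tiles; // output C tile
            const int k_id     = block_id / num_tiles; // split-K slice
            // ... run one GEMM work item for (tile_idx, k_id) ...
            (void)tile_idx;
            (void)k_id;
            block_id += grid_size; // next work item for this workgroup
        }
    }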
+template +struct GemmMultiDHostArgs +{ + CK_TILE_HOST GemmMultiDHostArgs() = default; + CK_TILE_HOST GemmMultiDHostArgs(const void* a_ptr_, + const void* b_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + const std::array& stride_Ds_, + index_t stride_E_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + const std::array stride_Ds; + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +template +struct GemmKernelMultiD +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; + + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + /// @brief Specify the layout configurations for A, B, E and D + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, E and D + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using EDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "ALayout and ADataType must be scalars."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "BLayout and BDataType must be scalars."); + + /// @brief ELayout and EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "ELayout and EDataType must be scalars."); + + /// @brief DsLayout and DsDataType are expected to be tuples, not scalars. + static_assert(is_detected::value && + is_detected::value && + DsLayout::size() == DsDataType::size() && DsLayout::size() > 0, + "DsLayout and DsDataType must be tuples and must have the same size."); + + /// @brief NumATensor and NumBTensor are always 1; the number of D tensors is set by + /// the user. + static constexpr index_t NumATensor = 1; + static constexpr index_t NumBTensor = 1; + static constexpr index_t NumDTensor = DsDataType::size(); + + CK_TILE_HOST static auto GetName() -> const std::string + { + return UniversalGemmKernel::GetName(); + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3 + { + return UniversalGemmKernel::GridSize(M, N, KBatch); + } + + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + return UniversalGemmKernel::MaxOccupancyGridSize(s); + } + + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return UniversalGemmKernel::BlockSize(); + } + + CK_TILE_HOST static constexpr auto + MakeKernelArgs(const GemmMultiDHostArgs& hostArgs) -> + typename UniversalGemmKernel::KernelArgs + { + /// @brief Universal GEMM requires array objects and corresponding stride information for + /// matrices A, B, and D.
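Host code populates GemmMultiDHostArgs before this wrapping happens. A hypothetical setup for NumDTensor == 2 (all sizes, strides and tensor roles below are made-up illustrations, assuming this header is included and index_t is an integer type):

    #include <array>

    // Hypothetical: E = A * B with two N-major D tensors fused in the
    // epilogue (say a bias and a residual).
    void make_multi_d_args(const void* a, const void* b, const void* bias,
                           const void* residual, void* e)
    {
        constexpr int num_d = 2;
        const std::array<const void*, num_d> ds{bias, residual};
        const std::array<ck_tile::index_t, num_d> stride_Ds{4096, 4096};
        ck_tile::GemmMultiDHostArgs<num_d> args{a, b, ds, e,
                                                /*k_batch*/ 1,
                                                /*M*/ 2048, /*N*/ 4096, /*K*/ 1024,
                                                /*stride_A*/ 1024, /*stride_B*/ 4096,
                                                stride_Ds, /*stride_E*/ 4096};
        (void)args; // pass to GemmKernelMultiD::MakeKernelArgs(args)
    }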
+ return UniversalGemmKernel::MakeKernelArgs( + UniversalGemmHostArgs({hostArgs.a_ptr}, + {hostArgs.b_ptr}, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.k_batch, + hostArgs.M, + hostArgs.N, + hostArgs.K, + {hostArgs.stride_A}, + {hostArgs.stride_B}, + hostArgs.stride_Ds, + hostArgs.stride_E)); + } + + CK_TILE_HOST static auto + IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool + { + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + + CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void + { + UniversalGemmKernel{}.template operator()(kargs); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 28e8bee908..0a6bacdc42 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -69,8 +69,8 @@ struct GemmTile2DPartitioner * @param blockIdy WGP's Y index. * @return const tuple Tuple containing 2D output C-tile index. */ - CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept - -> const tuple + CK_TILE_DEVICE static auto + GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple { const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx); const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy); @@ -137,8 +137,8 @@ struct GemmTile1DPartitioner * @param blockIdx WGP's index. * @return const tuple Tuple containing 2D output C-tile index. */ - CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept - -> const tuple + CK_TILE_DEVICE static auto + GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple { const index_t NBlocks = integer_divide_ceil(N_, NPerBlock); @@ -188,9 +188,8 @@ struct OffsettedTile1DPartitioner * @param [in] N Gemm's N dimension. * @return Returns a `tuple` [Im, In] with shifted index. */ - [[nodiscard]] CK_TILE_DEVICE static auto - GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept - -> const tuple + [[nodiscard]] CK_TILE_DEVICE static auto GetOffsetedTileIndex( + index_t block_start, index_t M, index_t N) noexcept -> const tuple { const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start); return make_tuple(iM, iN); @@ -271,8 +270,8 @@ struct GemmSpatiallyLocalTilePartitioner * @param [in] block_1d_id WGP's index. * @return const tuple Tuple containing 2D output C-tile index. */ - CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept - -> const tuple + CK_TILE_DEVICE auto + GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple { const auto M0 = integer_divide_ceil(M, MPerBlock); const auto N0 = integer_divide_ceil(N, NPerBlock); diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 533cabb736..921ea11720 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -16,37 +16,116 @@ namespace ck_tile { +/// @brief The Grouped GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating the kernel +/// arguments object. It contains all necessary information required to build proper kernel +/// arguments and launch the kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides.
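The GetOutputTileIndex signatures reformatted in the partitioner hunks above all implement the same flat-id decomposition. For the 1D partitioner it reduces to the following sketch, assuming the conventional row-major walk over C tiles that the NBlocks computation in the hunk suggests:

    // Flat workgroup id -> 2D C-tile coordinate, GemmTile1DPartitioner style.
    struct TileIndex
    {
        int iM, iN;
    };

    TileIndex get_output_tile_index(int block_id, int N, int NPerBlock)
    {
        const int NBlocks = (N + NPerBlock - 1) / NPerBlock; // tiles along N
        return {block_id / NBlocks, block_id % NBlocks};
    }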
+struct GroupedGemmHostArgs +{ + CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_E_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + struct GemmTransKernelArg { - GemmKernelArgs<> group_karg; + UniversalGemmKernelArgs<> group_karg; ck_tile::index_t block_start; ck_tile::index_t block_end; GemmTransKernelArg() = delete; - GemmTransKernelArg(GemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end) + GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end) : group_karg{karg}, block_start{bl_start}, block_end{bl_end} { } - GemmTransKernelArg(GemmKernelArgs<>&& karg) : group_karg{karg}, block_start{0}, block_end{0} {} + GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg) + : group_karg{karg}, block_start{0}, block_end{0} + { + } }; template -struct GroupedGemmKernel : public GemmKernel +struct GroupedGemmKernel { + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using Base = UniversalGemmKernel; + using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + /// @brief Specify the layout configurations for A, B, C/E + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, C/E using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
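Each GemmTransKernelArg above pairs one group's kernel arguments with a [block_start, block_end) slice of the launch grid, and MakeKargs below accumulates these ranges group by group. The bookkeeping reduces to this sketch (the real code derives each group's tile count from its M, N and the tile partitioner):

    #include <vector>

    struct Range
    {
        int block_start, block_end;
    };

    // Give group i the contiguous block range [start, start + tiles_i).
    std::vector<Range> assign_block_ranges(const std::vector<int>& tiles_per_group)
    {
        std::vector<Range> ranges;
        int start = 0;
        for(const int tiles : tiles_per_group)
        {
            ranges.push_back({start, start + tiles});
            start += tiles;
        }
        return ranges;
    }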
+ static_assert(!is_detected::value && + !is_detected::value, + "C/ELayout and C/EDataType must be scalars."); + using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; - using Base = GemmKernel; using Kernel = GroupedGemmKernel; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; @@ -57,7 +136,7 @@ struct GroupedGemmKernel : public GemmKernel, + return concat('_', "gemm_grouped", gemm_prec_str(), concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), concat('x', P_::kPadM, P_::kPadN, P_::kPadK), @@ -66,7 +145,7 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) -> std::size_t + GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } @@ -95,8 +174,7 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) + CK_TILE_HOST static auto GridSize(const std::vector& gemm_descs) { index_t grid_size = 0; for(const auto& it_desc : gemm_descs) @@ -108,8 +186,7 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) - -> std::vector + MakeKargs(const std::vector& gemm_descs) -> std::vector { std::vector gemm_kernel_args_; index_t group_count = ck_tile::type_convert(gemm_descs.size()); @@ -138,18 +215,19 @@ struct GroupedGemmKernel : public GemmKernel{type_convert(gemm_descs[i].a_ptr), - type_convert(gemm_descs[i].b_ptr), - {}, - type_convert(gemm_descs[i].e_ptr), - M, - N, - K, - stride_a, - stride_b, - {}, - stride_e, - gemm_descs[i].k_batch}; + auto karg = + UniversalGemmKernelArgs<>{{type_convert(gemm_descs[i].a_ptr)}, + {type_convert(gemm_descs[i].b_ptr)}, + {/*ds_ptr*/}, + type_convert(gemm_descs[i].e_ptr), + M, + N, + K, + {stride_a}, + {stride_b}, + {/*stride_ds*/}, + stride_e, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -181,7 +259,7 @@ struct GroupedGemmKernel : public GemmKernel& kargs, + CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs, const tuple& block_idx_2d, const index_t block_idx_z) const { @@ -192,10 +270,10 @@ struct GroupedGemmKernel : public GemmKernel(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; - const BDataType* b_ptr = - static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + + splitk_batch_offset.as_k_split_offset[0]; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + + splitk_batch_offset.bs_k_split_offset[0]; CDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS @@ -208,7 +286,15 @@ struct GroupedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + Base::RunGemm({a_ptr}, + {b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + i_m, + i_n); } } @@ -224,7 +310,8 @@ struct GroupedGemmKernel : public GemmKernel& kargs, + const UniversalGemmKernelArgs<>& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -242,7 +329,7 @@ struct GroupedGemmKernel : public GemmKernel( - a_ptr, b_ptr, {}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = @@ -258,8 +345,12 @@ struct GroupedGemmKernel : public GemmKernel +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include 
"ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/// @brief The Universal GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref UniversalGemmKernel "UniversalGemmKernel" when creating +/// kernel arguments object. It contain all necessary information required to build proper +/// kernel argument and launch kernel on GPU. This structure defines the GEMM problem +/// configuration by stating all required information like M,N,K sizes and respective strides. +/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required). +/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required). +/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not +/// required). +template +struct UniversalGemmHostArgs +{ + CK_TILE_HOST UniversalGemmHostArgs(const std::array& as_ptr_, + const std::array& bs_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + const std::array& stride_As_, + const std::array& stride_Bs_, + const std::array& stride_Ds_, + index_t stride_E_) + : as_ptr(as_ptr_), + bs_ptr(bs_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_As(stride_As_), + stride_Bs(stride_Bs_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const std::array as_ptr; + const std::array bs_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + index_t M; + index_t N; + index_t K; + const std::array stride_As; + const std::array stride_Bs; + const std::array stride_Ds; + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +/// @brief The GEMM kernel device arguments. +template +struct UniversalGemmKernelArgs +{ + /// @brief The As input tensor's pointer to device memory. + const std::array as_ptr; + /// @brief The Bs input tensor's pointer to device memory. + const std::array bs_ptr; + /// @brief The Ds input tensor's pointer to device memory. + const std::array ds_ptr; + /// @brief The E output tensor's pointer to device memory. + void* e_ptr; + /// @brief GEMM's M dimension size. + index_t M; + /// @brief GEMM's N dimension size. + index_t N; + /// @brief GEMM's K dimension size. + index_t K; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of As tensor. + std::array stride_As; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of Bs tensor. + std::array stride_Bs; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of Ds tensor. + std::array stride_Ds; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of E tensor. + index_t stride_E; + index_t k_batch; +}; + +/// @brief The Universal GEMM kernel template. +/// +/// @paragraph Overview Overview +/// This class provides the generic matrix multiplication kernel template. By semantic +/// division of GEMM algorithm into following parts we achieve flexible, versatile +/// and robust kernel implementation. +/// +/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() +/// function call operator" which determines the work scope of each workgroup. 
+/// @brief The Universal GEMM kernel template.
+///
+/// @paragraph Overview Overview
+///   This class provides the generic matrix multiplication kernel template. By semantically
+///   dividing the GEMM algorithm into the following parts we achieve a flexible, versatile and
+///   robust kernel implementation.
+///
+///   @li @b Prolog - The start of the GEMM kernel implementation in the @ref operator()
+///       "function call operator", which determines the work scope of each workgroup.
+///   @li @b GemmPipeline - The core part, the @a "heart", of the matrix multiplication
+///       algorithm. This is the place where each workgroup loads data from global memory and
+///       carries out the dot products.
+///   @li @b Epilogue - The @a "final" part of the matrix multiplication implementation,
+///       responsible for storing the results to global memory. This is also the place where
+///       any additional operator fusion may take place.
+///
+///   Additionally, both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
+///   "EpiloguePipeline" are parameterized with a so-called @a Policy, which determines all
+///   internal details of those functional parts. You can think of it as the GEMM and epilogue
+///   pipelines providing the control-flow logic, while the policies steer its details.
+///   Moreover, the policy is responsible for defining all necessary data layouts and the
+///   threads' work distribution.
+///
+/// @tparam TilePartitioner_  The type of class providing the mapping of a workgroup index onto
+///                           the output data tile to be calculated. It determines the
+///                           workgroup-to-data relationship (in other words, which data is
+///                           processed and calculated by which workgroup).
+/// @tparam GemmPipeline_     The type of class which provides the core part of the matrix
+///                           multiplication. This class should provide the implementation of
+///                           data loading from global memory and of the block-wise matrix
+///                           multiplication - that is, the work done from a single workgroup's
+///                           point of view.
+/// @tparam EpiloguePipeline_ The type of class providing the final part of the matrix
+///                           multiplication implementation. It is responsible for storing the
+///                           results calculated by @ref GemmPipeline_ "GemmPipeline" to the
+///                           output E tensor in global memory.
+template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
+struct UniversalGemmKernel
+{
+    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
+    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
+    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
+
+    static constexpr bool ADataTypeIsTuple =
+        is_detected<is_tuple, typename GemmPipeline::ADataType>::value;
+    static constexpr bool BDataTypeIsTuple =
+        is_detected<is_tuple, typename GemmPipeline::BDataType>::value;
+    static constexpr bool DDataTypeIsTuple =
+        is_detected<is_tuple, typename GemmPipeline::DsDataType>::value;
+    static constexpr bool ALayoutIsTuple =
+        is_detected<is_tuple, typename GemmPipeline::ALayout>::value;
+    static constexpr bool BLayoutIsTuple =
+        is_detected<is_tuple, typename GemmPipeline::BLayout>::value;
+    static constexpr bool DLayoutIsTuple =
+        is_detected<is_tuple, typename GemmPipeline::DsLayout>::value;
+
+    using AsLayout = std::conditional_t<ALayoutIsTuple,
+                                        remove_cvref_t<typename GemmPipeline::ALayout>,
+                                        remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
+    using BsLayout = std::conditional_t<BLayoutIsTuple,
+                                        remove_cvref_t<typename GemmPipeline::BLayout>,
+                                        remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
+
+    using DsLayout = std::conditional_t<DLayoutIsTuple,
+                                        remove_cvref_t<typename GemmPipeline::DsLayout>,
+                                        remove_cvref_t<tuple<typename GemmPipeline::DsLayout>>>;
+
+    using AsDataType = std::conditional_t<ADataTypeIsTuple,
+                                          remove_cvref_t<typename GemmPipeline::ADataType>,
+                                          remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
+
+    using BsDataType = std::conditional_t<BDataTypeIsTuple,
+                                          remove_cvref_t<typename GemmPipeline::BDataType>,
+                                          remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
+
+    using DsDataType = std::conditional_t<DDataTypeIsTuple,
+                                          remove_cvref_t<typename GemmPipeline::DsDataType>,
+                                          remove_cvref_t<tuple<typename GemmPipeline::DsDataType>>>;
+
+    using ELayout   = remove_cvref_t<typename GemmPipeline::CLayout>;
+    using EDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
+
+    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
+
+    // Get the persistent kernel if the pipeline has it available
+    struct has_persistent_kernel
+    {
+        template <typename T>
+        using has_persistent_type = decltype(T::UsePersistentKernel);
+
+        static constexpr bool value = []() {
+            if constexpr(is_detected<has_persistent_type, GemmPipeline>{})
+                return GemmPipeline::UsePersistentKernel;
+            else
+                return false;
+        }();
+    };
+    static constexpr bool PersistentKernel = has_persistent_kernel::value;
+
+    static constexpr auto I0 = number<0>();
+    static constexpr auto I1 = number<1>();
+    static constexpr auto I2 = number<2>();
+    static constexpr auto I3 = number<3>();
+
+    static constexpr index_t NumATensor = AsDataType::size();
+    static constexpr index_t NumBTensor = BsDataType::size();
+    static constexpr index_t NumDTensor = DsDataType::size();
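    // Illustration of the normalization above, assuming fp16-style pipeline types: if
    // GemmPipeline::ADataType is the scalar fp16_t, ADataTypeIsTuple is false and AsDataType
    // becomes tuple<fp16_t>, so NumATensor == 1. If it is already tuple<fp16_t, fp8_t> (e.g. a
    // mixed-precision multi-A pipeline), the tuple is kept as-is and NumATensor == 2. The same
    // normalization applies to the B/D data types and the corresponding layouts, letting the
    // rest of the kernel treat every operand uniformly as an array of tensors.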
+
+    using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
+    using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
+
+    static_assert(AsLayout::size() == AsDataType::size(),
+                  "The size of AsLayout and AsDataType should be the same");
+
+    static_assert(BsLayout::size() == BsDataType::size(),
+                  "The size of BsLayout and BsDataType should be the same");
+
+    static_assert(DsLayout::size() == DsDataType::size(),
+                  "The size of DsLayout and DsDataType should be the same");
+
+    using KernelArgs = UniversalGemmKernelArgs<NumATensor, NumBTensor, NumDTensor>;
+
+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
+        // clang-format on
+    }
+
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
+    {
+        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
+    }
+
+    /**
+     * @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
+     * @return The maximum occupancy grid size.
+     * @note This function queries the maximum occupancy of the kernel using
+     *       `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
+     */
+    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
+    {
+        using Kernel      = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
+        const auto kernel = kentry;
+        int occupancy;
+        hip_check_error(
+            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
+        const int grid_size = get_available_compute_units(s) * occupancy;
+        return dim3(grid_size, 1, 1);
+    }
+
+    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
+
+    CK_TILE_HOST static constexpr KernelArgs
+    MakeKernelArgs(const UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs)
+    {
+        return KernelArgs{hostArgs.as_ptr,
+                          hostArgs.bs_ptr,
+                          hostArgs.ds_ptr,
+                          hostArgs.e_ptr,
+                          hostArgs.M,
+                          hostArgs.N,
+                          hostArgs.K,
+                          hostArgs.stride_As,
+                          hostArgs.stride_Bs,
+                          hostArgs.stride_Ds,
+                          hostArgs.stride_E,
+                          hostArgs.k_batch};
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
+    }
+
+    struct SplitKBatchOffset
+    {
+        __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
+        {
+            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+            const index_t K_t   = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
+            const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
+
+            static_for<0, NumATensor, 1>{}([&](auto index) {
+                using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
+                if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
+                {
+                    as_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead);
+                }
+                else if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::ColumnMajor>)
+                {
+                    as_k_split_offset[index] =
+                        __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_As[index]);
+                }
+            });
+
+            static_for<0, NumBTensor, 1>{}([&](auto index) {
+                using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
+                if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
+                {
+                    bs_k_split_offset[index] =
+                        __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_Bs[index]);
+                }
+                else if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
+                {
+                    bs_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead);
+                }
+            });
+
+            if(k_id < static_cast<std::size_t>(kargs.k_batch - 1))
+            {
+                splitted_k = __builtin_amdgcn_readfirstlane(KRead);
+            }
+            else
+            {
+                splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
+            }
+        }
+
+        std::array<index_t, NumATensor> as_k_split_offset;
+        std::array<index_t, NumBTensor> bs_k_split_offset;
+        index_t splitted_k;
+    };
+
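    // Worked example of the split-K bookkeeping above (illustrative numbers, not from the
    // source): with K = 1000, k_batch = 4 and a warp-tile K of K1 = 16, K_t = 64 and
    // KRead = ((1000 + 63) / 64) * 16 = 256. Blocks with k_id 0..2 each consume
    // splitted_k = 256 elements of K, and the last block (k_id == 3) takes the remainder
    // 1000 - 3 * 256 = 232. For a row-major A tensor the element offset is simply
    // k_id * KRead; for a column-major A it is additionally scaled by stride_As[i] to skip
    // whole columns.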
+    CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
+    {
+        if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
+                     is_any_of<EDataType, fp16_t, bf16_t>::value)
+        {
+            if(kargs.k_batch != 1)
+            {
+                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                {
+                    CK_TILE_ERROR("Conditions not met for k_batch > 1!");
+                }
+                return false;
+            }
+        }
+
+        bool AsTensorIsValid = true;
+        static_for<0, NumATensor, 1>{}([&](auto index) {
+            using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
+            if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
+            {
+                if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
+                   GemmPipeline::kPadK == false)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR(
+                            "Can't support K that is not a multiple of k_batch * KPerBlock "
+                            "without padding!");
+                    }
+                    AsTensorIsValid = false;
+                }
+                if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
+                    }
+                    AsTensorIsValid = false;
+                }
+            }
+            else
+            {
+                if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR(
+                            "Can't support M that is not a multiple of MPerBlock without padding!");
+                    }
+                    AsTensorIsValid = false;
+                }
+                if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
+                    }
+                    AsTensorIsValid = false;
+                }
+            }
+        });
+
+        bool BsTensorIsValid = true;
+        static_for<0, NumBTensor, 1>{}([&](auto index) {
+            using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
+            if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
+            {
+                if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR(
+                            "Can't support N that is not a multiple of NPerBlock without padding!");
+                    }
+                    BsTensorIsValid = false;
+                }
+                if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
+                    }
+                    BsTensorIsValid = false;
+                }
+            }
+            else
+            {
+                if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
+                   GemmPipeline::kPadK == false)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR(
+                            "Can't support K that is not a multiple of k_batch * KPerBlock "
+                            "without padding!");
+                    }
+                    BsTensorIsValid = false;
+                }
+                if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
+                    }
+                    BsTensorIsValid = false;
+                }
+            }
+        });
+
+        bool DTensorIsValid = true;
+        static_for<0, NumDTensor, 1>{}([&](auto index) {
+            using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
+            if(std::is_same_v<DiLayout, ELayout> == false)
+            {
+                DTensorIsValid = false;
+            }
+            if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
+            {
+                if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
+                                      "NPerBlock without padding!");
+                    }
+                    DTensorIsValid = false;
+                }
+                if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
+                    }
+                    DTensorIsValid = false;
+                }
+            }
+            else
+            {
+                if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
+                                      "MPerBlock without padding!");
+                    }
+                    DTensorIsValid = false;
+                }
+                if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
+                    }
+                    DTensorIsValid = false;
+                }
+            }
+        });
+
+        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
+        {
+            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
+            {
+                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                {
+                    CK_TILE_ERROR(
+                        "Can't support N that is not a multiple of NPerBlock without padding!");
+                }
+                return false;
+            }
+            if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
+            {
+                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                {
+                    CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
+                }
+                return false;
+            }
+        }
+        else
+        {
+            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
+            {
+                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                {
+                    CK_TILE_ERROR(
+                        "Can't support M that is not a multiple of MPerBlock without padding!");
+                }
+                return false;
+            }
+            if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
+            {
+                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                {
+                    CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
+                }
+                return false;
+            }
+        }
+        return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
+    }
+
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static auto
+    MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
+                        const std::array<const BDataType*, NumBTensor>& bs_ptr,
+                        const std::array<const void*, NumDTensor>& ds_ptr,
+                        EDataType* e_ptr,
+                        const KernelArgs& kargs,
+                        const SplitKBatchOffset& splitk_batch_offset)
+    {
+        static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
+
+        const auto& as_tensor_view = generate_tuple(
+            [&](auto i) {
+                using AiLayout   = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
+                using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
+                if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
+                {
+                    return make_naive_tensor_view<address_space_enum::global>(
+                        static_cast<const AiDataType*>(as_ptr[i]),
+                        make_tuple(kargs.M, splitk_batch_offset.splitted_k),
+                        make_tuple(kargs.stride_As[i], 1),
+                        number<GemmPipeline::GetVectorSizeA()>{},
+                        number<1>{});
+                }
+                else
+                {
+                    return make_naive_tensor_view<address_space_enum::global>(
+                        static_cast<const AiDataType*>(as_ptr[i]),
+                        make_tuple(splitk_batch_offset.splitted_k, kargs.M),
+                        make_tuple(kargs.stride_As[i], 1),
+                        number<GemmPipeline::GetVectorSizeA()>{},
+                        number<1>{});
+                }
+            },
+            number<NumATensor>{});
+
+        const auto& bs_tensor_view = generate_tuple(
+            [&](auto i) {
+                using BiLayout   = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
+                using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
+                if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
+                {
+                    if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
+                    {
+                        constexpr index_t K1 = GemmPipeline::GetSmemPackB();
+                        const index_t K0     = splitk_batch_offset.splitted_k / K1;
+                        constexpr index_t VectorSizeB =
+                            std::min(K1, GemmPipeline::GetVectorSizeB());
+                        const auto b_k0_n_k1_desc =
+                            make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
+                                                         make_tuple(kargs.N * K1, K1, I1),
+                                                         number<VectorSizeB>{},
+                                                         number<1>{});
+                        const auto b_n_k_desc = transform_tensor_descriptor(
+                            b_k0_n_k1_desc,
+                            make_tuple(make_merge_transform(make_tuple(K0, K1)),
+                                       make_pass_through_transform(kargs.N)),
+                            make_tuple(sequence<0, 2>{}, sequence<1>{}),
+                            make_tuple(sequence<0>{}, sequence<1>{}));
+                        return make_tensor_view<address_space_enum::global>(
+                            static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
+                    }
+                    else
+                    {
+                        return make_naive_tensor_view<address_space_enum::global>(
+                            bs_ptr[i],
+                            make_tuple(splitk_batch_offset.splitted_k, kargs.N),
+                            make_tuple(kargs.stride_Bs[i], 1),
+                            number<GemmPipeline::GetVectorSizeB()>{},
+                            number<1>{});
+                    }
+                }
+                else
+                {
+                    if
constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = + std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view( + static_cast(bs_ptr[i]), b_n_k_desc); + } + else + { + if constexpr(GemmPipeline::Preshuffle) + { + index_t kFlatK = + GemmPipeline::BlockGemmShape::flatKPerWarp * + (splitk_batch_offset.splitted_k / + TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_Bs[i], 1), + number{}, + number<1>{}); + } + } + } + }, + number{}); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + using DDataType_ = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + // TODO: enable vector write for C in ColMajor + const auto& e_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm. 
+ make_tuple(kargs.stride_E, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_E), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& as_pad_view = generate_tuple( + [&](auto i) { + const auto& a_tensor_view = views.at(I0); + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + const auto& b_flat_pad_view = views.at(I1); + + const auto& bs_pad_view = generate_tuple( + [&](auto i) { + const auto& b_tensor_view = views.at(I1); + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + const auto& d_tensor_view = views.at(I2); + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // TODO vector write in for C in ColMajor + const auto& e_pad_view = [&]() { + const auto& e_tensor_view = views.at(I3); + if constexpr(std::is_same_v) + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + if constexpr(GemmPipeline::Preshuffle) + { + // For flatmm, we need to use the flat B tensor view + return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view); + } + else + { + return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view); + } + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& as_pad_view = views.at(I0); + const auto& bs_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& e_pad_view = views.at(I3); + + const auto& as_block_window = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_m}); + } + }, + number{}); + + const auto& bs_block_window = generate_tuple( + [&](auto i) { + using BiLayout = remove_cvref_t>; + if constexpr(GemmPipeline::Preshuffle) + { + return make_tile_window( + bs_pad_view[i], + make_tuple(number{}, + number{}), + {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), + 0}); + } + else + { + if constexpr(std::is_same_v) + { + return make_tile_window(bs_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(bs_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_n}); + } + } + }, + number{}); + + const auto 
ds_block_window = generate_tuple(
+            [&](auto i) {
+                using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
+                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
+                {
+                    return make_tile_window(ds_pad_view[i],
+                                            make_tuple(number<TilePartitioner::MPerBlock>{},
+                                                       number<TilePartitioner::NPerBlock>{}),
+                                            {i_m, i_n});
+                }
+                else
+                {
+                    return make_tile_window(ds_pad_view[i],
+                                            make_tuple(number<TilePartitioner::NPerBlock>{},
+                                                       number<TilePartitioner::MPerBlock>{}),
+                                            {i_n, i_m});
+                }
+            },
+            number<NumDTensor>{});
+
+        auto e_block_window = make_tile_window(
+            e_pad_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+            {i_m, i_n});
+
+        return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
+    }
+
+    /**
+     * @brief Runs a single GEMM problem cooperatively by the whole workgroup.
+     *
+     * @param as_ptr input As pointers
+     * @param bs_ptr input Bs pointers
+     * @param ds_ptr input Ds pointers
+     * @param e_ptr output E pointer
+     * @param smem_ptr_0 The starting pointer of the shared memory block.
+     * @param kargs GEMM kernel arguments
+     * @param splitk_batch_offset Utility structure used to calculate the k-batch offsets.
+     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
+     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     *
+     */
+    template <bool UseDefaultScheduler = true>
+    CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
+                                       const std::array<const BDataType*, NumBTensor>& bs_ptr,
+                                       const std::array<const void*, NumDTensor>& ds_ptr,
+                                       EDataType* e_ptr,
+                                       void* smem_ptr_0,
+                                       const KernelArgs& kargs,
+                                       const SplitKBatchOffset& splitk_batch_offset,
+                                       const index_t block_idx_m,
+                                       const index_t block_idx_n)
+    {
+        // Create Gemm tensor views, pad views and tile windows
+        const auto& gemm_tensor_views_tuple =
+            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
+                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
+
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
+        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        const index_t num_loop = __builtin_amdgcn_readfirstlane(
+            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
+
+        // Run GEMM cooperatively by the whole workgroup.
+        const auto& as_block_window = gemm_tile_windows.at(I0);
+        const auto& bs_block_window = gemm_tile_windows.at(I1);
+        const auto& ds_block_window = gemm_tile_windows.at(I2);
+
+        const auto& c_block_tile = GemmPipeline{}.template operator()(
+            as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
+
+        if(UseDefaultScheduler || (get_warp_id() == 0))
+        {
+            // Run Epilogue Pipeline
+            auto& c_block_window = gemm_tile_windows.at(I3);
+
+            EpiloguePipeline{}.template operator()(
+                c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
+        }
+    }
+
+    /**
+     * @brief Runs a single GEMM problem cooperatively by the whole workgroup.
+     *
+     * @note RunGemm2LDS runs with two shared memory buffers using a ping-pong buffering
+     *       mechanism.
+     *
+     * @param as_ptr input As pointers
+     * @param bs_ptr input Bs pointers
+     * @param ds_ptr input Ds pointers
+     * @param e_ptr output E pointer
+     * @param smem_ptr_0 The starting pointer of the 1st shared memory block.
+     * @param smem_ptr_1 The starting pointer of the 2nd shared memory block.
+     * @param kargs GEMM kernel arguments
+     * @param splitk_batch_offset Utility structure used to calculate the k-batch offsets.
+     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
+     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
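     * @note An illustrative reading of the double-buffered scheme (an assumption about how the
     *       pipeline uses the two buffers, not a contract of this function): while the block
     *       GEMM consumes K-tile i from one LDS buffer, K-tile i+1 is prefetched into the
     *       other, and the two buffers swap roles every iteration.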
+ * + */ + CK_TILE_DEVICE static void RunGemm2LDS(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + + // Run GEMM cooperatively by whole workgroup. + const auto& as_block_window = gemm_tile_windows.at(I0); + const auto& bs_block_window = gemm_tile_windows.at(I1); + const auto& ds_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + + // Non-persistent kernel entry point + template > + CK_TILE_DEVICE void operator()(KernelArgs kargs) const + { + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + + // options + std::array as_ptr; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_ptr[i] = static_cast(kargs.as_ptr[i]) + + splitk_batch_offset.as_k_split_offset[i]; + }); + + std::array bs_ptr; + static_for<0, NumBTensor, 1>{}([&](auto i) { + bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + + splitk_batch_offset.bs_k_split_offset[i]; + }); + + EDataType* e_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + } + + // Persistent kernel entry point + template , typename = void> + CK_TILE_DEVICE void operator()(KernelArgs kargs) const + { + const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); + const auto num_tiles = + __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N)); + const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch); + auto block_id = 
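        // Worked decomposition example (illustrative numbers, not from the source): with
        // num_tiles = 128 output tiles and k_batch = 2, num_work = 256. A block currently at
        // block_id = 130 processes tile_idx = 130 % 128 = 2 for split-K slice 130 / 128 = 1,
        // then strides forward by grid_size until block_id reaches num_work.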
__builtin_amdgcn_readfirstlane(get_block_id()); + + while(block_id < num_work) + { + // Get the tile index for this block + const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + // Get the SplitK offset for this block + const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles); + const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); + + std::array as_ptr; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_ptr[i] = static_cast(kargs.as_ptr[i]) + + splitk_batch_offset.as_k_split_offset[i]; + }); + + std::array bs_ptr; + static_for<0, NumBTensor, 1>{}([&](auto i) { + bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + + splitk_batch_offset.bs_k_split_offset[i]; + }); + + EDataType* e_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + // Run the GEMM + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + // Advance to the next work item + block_id += grid_size; + if(block_id >= num_work) + { + break; + } + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 6861adb153..2bee550b3c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -20,6 +20,13 @@ struct GemmPipelineAgBgCrImplBase static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; +#if defined(__gfx950__) + static constexpr bool is_a_load_tr = std::is_same_v; + static constexpr bool is_b_load_tr = std::is_same_v; +#else + static constexpr bool is_a_load_tr = false; + static constexpr bool is_b_load_tr = false; +#endif CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } @@ -50,11 +57,15 @@ struct GemmPipelineAgBgCrImplBase store_tile(lds_tile_window, block_tile_tmp); } - template + template CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile, - const SrcTileWindow& lds_tile_window) const + const SrcTileWindow& lds_tile_window, + bool_constant = {}) const { - load_tile(dst_block_tile, lds_tile_window); + if constexpr(LoadTranspose) + dst_block_tile = load_tile_transpose(lds_tile_window); + else + load_tile(dst_block_tile, lds_tile_window); } CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const @@ -96,14 +107,25 @@ struct GemmPipelineAgBgCrImplBase Policy::template MakeADramTileDistribution()); // A LDS tile window for store - auto a_copy_lds_window = 
make_tile_window( - a_lds_block_view, make_tuple(number{}, number{}), {0, 0}); + auto a_lds_shape = []() { + if constexpr(is_a_load_tr) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}); + auto a_lds_load_tile_distr = []() { + if constexpr(is_a_load_tr) + return make_static_tile_distribution( + typename InputTileDistributionTraits< + typename ALdsLoadTileDistr::DstrEncode, + typename Problem::ADataType>::TransposedDstrEncode{}); + else + return ALdsLoadTileDistr{}; + }(); auto a_lds_gemm_window = - make_tile_window(a_lds_block_view, - make_tuple(number{}, number{}), - {0, 0}, - ALdsLoadTileDistr{}); + make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr); return make_tuple(std::move(a_copy_dram_window), std::move(a_copy_lds_window), @@ -130,14 +152,25 @@ struct GemmPipelineAgBgCrImplBase // TODO: Do we really need those two tile windows??? // They're exactly same... // B LDS tile window for store - auto b_copy_lds_window = make_tile_window( - b_lds_block_view, make_tuple(number{}, number{}), {0, 0}); + auto b_lds_shape = []() { + if constexpr(is_b_load_tr) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); + auto b_lds_load_tile_distr = []() { + if constexpr(is_b_load_tr) + return make_static_tile_distribution( + typename InputTileDistributionTraits< + typename BLdsLoadTileDistr::DstrEncode, + typename Problem::BDataType>::TransposedDstrEncode{}); + else + return BLdsLoadTileDistr{}; + }(); auto b_lds_gemm_window = - make_tile_window(b_lds_block_view, - make_tuple(number{}, number{}), - {0, 0}, - BLdsLoadTileDistr{}); + make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr); return make_tuple(std::move(b_copy_dram_window), std::move(b_copy_lds_window), diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 6d0db060cd..5b7903a9e7 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -153,15 +153,20 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop); static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; using Base::UsePersistentKernel; [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); return concat('_', "pipeline_AgBgCrCompV3", - concat('x', MPerBlock, NPerBlock, KPerBlock, BlockSize), - concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', WaveNumM, WaveNumN), concat('x', kPadM, kPadN, kPadK)); // clang-format on } @@ -467,7 +472,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = 
make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -478,7 +483,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -494,7 +499,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); __builtin_amdgcn_sched_barrier(0); @@ -506,7 +512,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -517,7 +523,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -536,7 +542,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -578,7 +585,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); } block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } // __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 8e6bab21be..ac91c2f58f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -141,6 +141,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off @@ -305,17 +308,23 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); - auto a_copy_lds_window0 = make_tile_window( - a_lds_block0, make_tuple(number{}, number{}), {0, 0}); + constexpr auto a_lds_shape = []() { + if constexpr(is_a_load_tr_v()) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + auto a_copy_lds_window0 = 
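// Note on the transposed-load path (the gfx950 ds-read-transpose feature introduced above, as
// this reviewer reads it): when is_a_load_tr / is_b_load_tr is set, LocalPrefetch dispatches to
// load_tile_transpose and the LDS window shape and tile distribution are built for the
// transposed tile, so the data lands in registers already in the layout the MFMA block GEMM
// expects; on other targets both flags are false and the original load_tile path is used.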
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}); + auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0}); - auto a_copy_lds_window1 = make_tile_window( - a_lds_block1, make_tuple(number{}, number{}), {0, 0}); - - auto b_copy_lds_window0 = make_tile_window( - b_lds_block0, make_tuple(number{}, number{}), {0, 0}); - - auto b_copy_lds_window1 = make_tile_window( - b_lds_block1, make_tuple(number{}, number{}), {0, 0}); + constexpr auto b_lds_shape = []() { + if constexpr(is_b_load_tr_v()) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0}); + auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0}); // Block GEMM auto block_gemm = BlockGemm(); @@ -325,7 +334,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -336,7 +345,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -354,51 +363,53 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 block_sync_lds(); - constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution( - BlockGemm::MakeABlockDistributionEncode())){}; - constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution( - BlockGemm::MakeBBlockDistributionEncode())){}; + constexpr auto ALdsTileDistr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto BLdsTileDistr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + ALdsTile a_block_tile0, a_block_tile1; + BLdsTile b_block_tile0, b_block_tile1; - ALdsTile a_block_tile0; - ALdsTile a_block_tile1; - - BLdsTile b_block_tile0; - BLdsTile b_block_tile1; - + constexpr auto a_lds_input_tile_distr = [&]() { + if constexpr(is_a_load_tr_v()) + return make_static_tile_distribution( + typename InputTileDistributionTraits< + decltype(BlockGemm::MakeABlockDistributionEncode()), + typename Problem::ADataType>::TransposedDstrEncode{}); + else + return ALdsTileDistr; + }(); + constexpr auto b_lds_input_tile_distr = [&]() { + if constexpr(is_b_load_tr_v()) + return make_static_tile_distribution( + typename InputTileDistributionTraits< + decltype(BlockGemm::MakeBBlockDistributionEncode()), + typename Problem::BDataType>::TransposedDstrEncode{}); + else + return BLdsTileDistr; + }(); auto a_lds_ld_window0 = - make_tile_window(a_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr); auto a_lds_ld_window1 = - make_tile_window(a_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr); auto b_lds_ld_window0 = - make_tile_window(b_lds_block0, - make_tuple(number{}, 
number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr); auto b_lds_ld_window1 = - make_tile_window(b_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr); - static_assert( - !(is_tile_window_linear_v)&&!(is_tile_window_linear_v)&&!( - is_tile_window_linear_v< - decltype(b_lds_ld_window0)>)&&!(is_tile_window_linear_v), - "LDS windows must not be linear"); + static_assert(!is_tile_window_linear_v && + !is_tile_window_linear_v && + !is_tile_window_linear_v && + !is_tile_window_linear_v, + "LDS windows must not be linear"); - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); - Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -409,7 +420,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -433,10 +444,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // ping { block_sync_lds(); - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -448,7 +459,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 Base::LocalPrefill( a_copy_lds_window0, a_global_load_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -473,10 +484,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // pong { block_sync_lds(); - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); - Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -488,7 +499,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 Base::LocalPrefill( a_copy_lds_window1, a_global_load_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -521,9 +532,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // 3 { block_sync_lds(); - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); - 
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); - if constexpr(is_a_col_major) + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -534,7 +545,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -550,8 +561,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // 2 { block_sync_lds(); - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); block_gemm(c_block_tile, a_block_tile1, b_block_tile1); } // 1 @@ -565,8 +576,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // 2 { block_sync_lds(); - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); block_gemm(c_block_tile, a_block_tile0, b_block_tile0); static_for<0, 8, 1>{}([&](auto i) { ignore = i; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index a42ddd93a0..7d88c804f3 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -21,15 +21,27 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { // using AccDataType = float; - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + constexpr bool single_load_tr_length = + (DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) == + (WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size()); + constexpr auto wg_attr_num_access = + ((is_a_load_tr || is_b_load_tr) && !single_load_tr_length) + ? 
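// Illustrative reading of this dispatch (quantities assumed from the expression above): one
// ds-read-transpose instruction moves DS_READ_TR_SIZE() bytes per lane, i.e.
// DS_READ_TR_SIZE() / sizeof(ComputeDataType) elements. If that does not cover a lane's share
// of the warp tile (WarpTile_N * WarpTile_K / warp_size elements), the WarpGemm is configured
// with two accesses per operand (WGAttrNumAccessEnum::Double) instead of one (Single).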
WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single; + using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::BDataType, + typename Problem::CDataType, // AccDataType + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC, + false, + false, + wg_attr_num_access>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + typename Problem::BDataType, + typename Problem::CDataType, // AccDataType + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off @@ -272,10 +275,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto& b_lds_block = ab_lds_blocks.at(I1{}); // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution( - BlockGemm::MakeABlockDistributionEncode())){}; - constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution( - BlockGemm::MakeBBlockDistributionEncode())){}; + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); // A DRAM tile window for load // A LDS tile window for store @@ -332,7 +335,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -343,7 +346,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -373,12 +376,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -394,7 +398,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -427,12 +431,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + 
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -445,7 +450,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -461,14 +466,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem }); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); }; if constexpr(TailNum == TailNumber::One) { block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } else if constexpr(TailNum == TailNumber::Two) @@ -558,10 +565,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto& b_lds_block = ab_lds_blocks.at(I1{}); // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution( - BlockGemm::MakeABlockDistributionEncode())){}; - constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution( - BlockGemm::MakeBBlockDistributionEncode())){}; + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); // A DRAM tile window for load // A LDS tile window for store @@ -617,7 +624,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -628,7 +635,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -658,10 +665,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, + a_lds_gemm_window, + b_lds_gemm_window, + is_a_load_tr_v, + is_b_load_tr_v); // no second block_sync_lds because it's interwave - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -677,7 +688,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_element_func); } - if 
constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -709,10 +720,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto HotLoopTail = [&](auto tail_num) { static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, + a_lds_gemm_window, + b_lds_gemm_window, + is_a_load_tr_v, + is_b_load_tr_v); // no second block_sync_lds because it's interwave - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -725,7 +740,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -741,13 +756,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem }); block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, + a_lds_gemm_window, + b_lds_gemm_window, + is_a_load_tr_v, + is_b_load_tr_v); }; if constexpr(TailNum == TailNumber::One) { block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, + a_lds_gemm_window, + b_lds_gemm_window, + is_a_load_tr_v, + is_b_load_tr_v); } else if constexpr(TailNum == TailNumber::Two) { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 881467cb94..d8118a7f8f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -47,6 +47,8 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; + static constexpr bool Preshuffle = Problem::Preshuffle; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr index_t kLdsAlignmentInBytes = 16; @@ -282,9 +284,9 @@ struct GemmPipelineAGmemBGmemCRegV1 { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const BDataType & b) { return b; }, num_loop, p_smem); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 0f7f6369f0..0560ed9ba9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -394,12 +394,12 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::ComputeDataType, + AccDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy struct UniversalGemmBasePolicy { 
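The additions that open this policy gate the gfx950 transposed LDS load on the operand layouts: A qualifies for it when column-major, B when row-major, and both flags fold to false on every other target, so `if constexpr` keeps the existing shuffle path there. A minimal standalone sketch of the same dispatch pattern (the `DemoProblem` type, layout tags, and `use_transposed_load_a` name below are illustrative stand-ins, not the CK definitions):

#include <type_traits>

struct ColumnMajor {};
struct RowMajor {};

// Sketch only: layout-gated, arch-gated selection of a transposed-load path.
template <typename Problem>
inline constexpr bool use_transposed_load_a =
#if defined(__gfx950__)
    std::is_same_v<typename Problem::ALayout, ColumnMajor>;
#else
    false; // older targets keep the shuffle-through-registers path
#endif

struct DemoProblem
{
    using ALayout = ColumnMajor;
    using BLayout = RowMajor;
};

// The flag is a compile-time constant, so the rejected branch in the
// pipeline (register shuffle vs. hardware transposed load) is discarded.
static_assert(std::is_same_v<DemoProblem::ALayout, ColumnMajor>);

Further down, `GetBlockGemm()` derives the MFMA operand access count from the same flags; assuming a hypothetical 8-byte `DS_READ_TR_SIZE()`, an fp16 32x32x16 warp tile on a 64-lane wavefront gives vector_size = 8 / 2 = 4 and thread_elements = 32 * 16 / 64 = 8, so vector_size * 2 == thread_elements selects WGAttrNumAccessEnum::Double.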
+#if defined(__gfx950__)
+    template <typename Problem>
+    static constexpr bool is_a_load_tr =
+        std::is_same_v<remove_cvref_t<typename Problem::ALayout>, tensor_layout::gemm::ColumnMajor>;
+    template <typename Problem>
+    static constexpr bool is_b_load_tr =
+        std::is_same_v<remove_cvref_t<typename Problem::BLayout>, tensor_layout::gemm::RowMajor>;
+#else
+    template <typename Problem>
+    static constexpr bool is_a_load_tr = false;
+    template <typename Problem>
+    static constexpr bool is_b_load_tr = false;
+#endif
+
     static constexpr auto I0 = number<0>{};
     static constexpr auto I1 = number<1>{};
     static constexpr auto I2 = number<2>{};
@@ -22,51 +36,65 @@ struct UniversalGemmBasePolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
     {
-        using ADataType = remove_cvref_t<typename Problem::ADataType>;
+        using ADataType = remove_cvref_t<typename Problem::ADataType>;
         constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t KPack     = GetSmemPackA<Problem>();
-        constexpr auto DataTypeSize = sizeof(ADataType);
-        constexpr auto MLdsLayer =
-            (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
+        if constexpr(is_a_load_tr<Problem>)
+        {
+            // TODO: better lds descriptor for performance
+            constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
+                make_tuple(number{}, number{}),
+                make_tuple(number{}, number<1>{}),
+                number{},
+                number<1>{});
+            return a_lds_block_desc_0;
+        }
+        else
+        {
+            constexpr index_t KPack = GetSmemPackA<Problem>();
-        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(number{},
-                       number{},
-                       number{}),
-            make_tuple(number{}, number{}, number<1>{}),
-            number{},
-            number<1>{});
+            constexpr auto DataTypeSize = sizeof(ADataType);
+            constexpr auto MLdsLayer =
+                (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
-        constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
-            a_lds_block_desc_0,
-            make_tuple(make_xor_transform(make_tuple(number{},
-                                                     number{})),
-                       make_pass_through_transform(number{})),
-            make_tuple(sequence<1, 0>{}, sequence<2>{}),
-            make_tuple(sequence<1, 0>{}, sequence<2>{}));
+            constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
+                make_tuple(number{},
+                           number{},
+                           number{}),
+                make_tuple(number{}, number{}, number<1>{}),
+                number{},
+                number<1>{});
-        constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
-            a_lds_block_desc_permuted,
-            make_tuple(make_unmerge_transform(
-                           make_tuple(number{}, number{})),
-                       make_pass_through_transform(number{}),
-                       make_pass_through_transform(number{})),
-            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
+            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
+                a_lds_block_desc_0,
+                make_tuple(make_xor_transform(make_tuple(number{},
+                                                         number{})),
+                           make_pass_through_transform(number{})),
+                make_tuple(sequence<1, 0>{}, sequence<2>{}),
+                make_tuple(sequence<1, 0>{}, sequence<2>{}));
-        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
-            a_lds_block_desc_xk0_mnldslayer_mn_xk1,
-            make_tuple(make_merge_transform_v3_division_mod(
-                           make_tuple(number{}, number{})),
-                       make_merge_transform_v3_division_mod(
-                           make_tuple(number{}, number{}))),
-            make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
+            constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
+                a_lds_block_desc_permuted,
+                make_tuple(make_unmerge_transform(
+                               make_tuple(number{}, number{})),
+                           make_pass_through_transform(number{}),
+                           make_pass_through_transform(number{})),
+                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
+                make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
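(A note on the xor step above: `make_xor_transform` is the classic LDS bank-conflict-avoidance swizzle. One coordinate is XOR-ed with the other before the final merge, so threads writing a K-contiguous row and threads later reading an M-contiguous column hit distinct banks. A toy model of that addressing, assuming a 32-bank LDS and power-of-two tile sides; the names are illustrative, not CK's:)

// Toy model only; not the actual descriptor math.
constexpr int kBanks = 32;
constexpr int swizzled_col(int row, int col) { return col ^ (row % kBanks); }
// Column 0 of rows 0..31 spreads over banks 0..31 instead of all hitting bank 0:
static_assert(swizzled_col(0, 0) == 0 && swizzled_col(1, 0) == 1 && swizzled_col(31, 0) == 31);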
-        return a_lds_block_desc;
+            constexpr auto a_lds_block_desc = transform_tensor_descriptor(
+                a_lds_block_desc_xk0_mnldslayer_mn_xk1,
+                make_tuple(make_merge_transform_v3_division_mod(
+                               make_tuple(number{}, number{})),
+                           make_merge_transform_v3_division_mod(
+                               make_tuple(number{}, number{}))),
+                make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
+                make_tuple(sequence<0>{}, sequence<1>{}));
+
+            return a_lds_block_desc;
+        }
     }
 
     /**
@@ -78,14 +106,24 @@ struct UniversalGemmBasePolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
     {
-        // using BLayout = remove_cvref_t;
         using BDataType = remove_cvref_t<typename Problem::BDataType>;
         constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
 
 #if 1
-        // if constexpr(std::is_same_v)
+        if constexpr(is_b_load_tr<Problem>)
+        {
+            // TODO: better lds descriptor for performance
+            constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
+                make_tuple(number{}, number{}),
+                make_tuple(number{}, number<1>{}),
+                number{},
+                number<1>{});
+            return b_lds_block_desc_0;
+        }
+        else
+        // else if constexpr(std::is_same_v)
         {
             constexpr index_t KPack = GetSmemPackB<Problem>();
             constexpr auto BK0      = number{};
@@ -131,10 +169,10 @@ struct UniversalGemmBasePolicy
         constexpr index_t BlockSize   = Problem::kBlockSize;
         constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
         using TileEncodingPattern = TileDistributionEncodingPattern2D;
+            KPerBlock,
+            NPerBlock,
+            VecLoadSize,
+            BTileAccessPattern>;
 
         constexpr auto BK0 = number{};
         constexpr auto BK1 = number{};
@@ -584,17 +622,29 @@ struct UniversalGemmPipelineAgBgCrPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
     {
-        using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
-        using WarpTile   = typename Problem::BlockGemmShape::WarpTile;
+        using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
+        using WarpTile   = typename Problem::BlockGemmShape::WarpTile;
+
+        constexpr index_t vector_size =
+            DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
+        constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
+        constexpr auto wg_attr_num_access =
+            !(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
+            : vector_size == thread_elements                  ? WGAttrNumAccessEnum::Single
+            : vector_size * 2 == thread_elements              ? WGAttrNumAccessEnum::Double
+            : vector_size * 4 == thread_elements              ?
WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; + using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::ComputeDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC, + false, + Problem::UseStructuredSparsity, + wg_attr_num_access>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr index_t Preshuffle = Problem::Preshuffle; + static constexpr bool Preshuffle = Problem::Preshuffle; using Base::UsePersistentKernel; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -462,7 +462,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_flat_dram_block_window_tmp, num_loop, p_smem); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp index 6922ddf8a7..25aad329d9 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -430,12 +430,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::BDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockWeightPreshufflePolicy = BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy>>; #if defined(__gfx950__) +template using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; - + WarpGemmAtrributeMfma, + AttrNumAccess>>; #else +template using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) +template using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; #else +template using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl>>; #if defined(__gfx950__) +template using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplF16F16F32M32N32K16, + AttrNumAccess>>; #else +template using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) +template using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplF16F16F32M16N16K32, + AttrNumAccess>>; #else +template using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) @@ -123,22 +138,29 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; #if defined(__gfx950__) +template using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; - + WarpGemmAtrributeMfma, + AttrNumAccess>>; #else +template using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) +template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; #else +template using 
WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl>>; #if defined(__gfx950__) +template using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16, + AttrNumAccess>>; #else +template using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) +template using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32, + AttrNumAccess>>; #else +template using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl, - 2>>; + 2, + AttrNumAccess>>; #endif #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16>>; #else using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>>; +template using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; +template using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; +template using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; +template using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; + WarpGemmAtrributeMfma, + AttrNumAccess>>; using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl +// Number of groups of consecutive elements to fill in a ABKLane +enum class WGAttrNumAccessEnum +{ + Single = 1, + Double = 2, + Quad = 4, + Invalid = -1 +}; + +template struct WarpGemmAtrributeMfma { - using Impl = remove_cvref_t; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; @@ -25,27 +37,42 @@ struct WarpGemmAtrributeMfma static constexpr index_t kN = Impl::kN; static constexpr index_t kK = Impl::kK; static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t kCMLane = Impl::kCMLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - using AWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; - - using BWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + template + static constexpr auto get_warp_dstr_encoding() + { + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -73,12 +100,16 @@ struct 
WarpGemmAtrributeMfma } }; -template +template struct WarpGemmAtrributeMfmaIterateK { static_assert(kKIter > 0, "wrong!"); - using Impl = remove_cvref_t; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; @@ -104,17 +135,37 @@ struct WarpGemmAtrributeMfmaIterateK { if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } } else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); // each M blocks share the same data return tile_distribution_encoding< sequence, @@ -127,6 +178,8 @@ struct WarpGemmAtrributeMfmaIterateK } else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); // single block to multi-block thread mapping return tile_distribution_encoding< sequence<>, @@ -143,17 +196,38 @@ struct WarpGemmAtrributeMfmaIterateK { if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } } else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); // single block to multi-block thread mapping return tile_distribution_encoding< sequence<>, @@ -166,6 +240,8 @@ struct WarpGemmAtrributeMfmaIterateK } else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); // each N blocks share the same data return tile_distribution_encoding< sequence, @@ -289,10 +365,13 @@ struct WarpGemmAtrributeMfmaIterateK } }; -template +template struct WarpGemmAtrributeMfmaTransposedCDistribution { - using Impl = remove_cvref_t; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); using ADataType = typename Impl::BDataType; using BDataType = typename Impl::ADataType; @@ -312,21 +391,35 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - using AWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - 
tuple>, - sequence<2>, - sequence<1>>; - - using BWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + template + static constexpr auto get_warp_dstr_encoding() + { + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -450,10 +543,13 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB } }; -template +template struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution { - using Impl = remove_cvref_t; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; // swap A and B using ADataType = typename Impl::BDataType; @@ -478,80 +574,14 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() { - if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) - { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) - { - // single block to multi-block thread mapping - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) - { - // each N blocks share the same data - return tile_distribution_encoding< - sequence, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } + return WarpGemmAtrributeMfmaIterateK:: + get_bwarp_dstr_encoding(); } CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() { - if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) - { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) - { - // each M blocks share the same data - return tile_distribution_encoding< - sequence, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) - { - // single block to multi-block thread mapping - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } + return WarpGemmAtrributeMfmaIterateK:: + get_awarp_dstr_encoding(); } CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 80f38f263b..0831cf85c4 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1095,16 +1095,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base #if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( - bit_cast(a_vec), bit_cast(b_vec), 
c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; @@ -1119,16 +1119,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base #if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); #else ck_tile::ignore = a_vec; ck_tile::ignore = b_vec; @@ -1254,16 +1254,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base #if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #elif defined(__gfx908__) || defined(__gfx90a__) static_for<0, 8, 1>{}([&](auto k) { float a_f32 = @@ -1289,16 +1289,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base #if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v 
&& std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); #elif defined(__gfx908__) || defined(__gfx90a__) CVecType c_vec{0.f}; static_for<0, 8, 1>{}([&](auto k) { @@ -1580,7 +1580,7 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 { #if defined(__gfx94__) or defined(__gfx95__) c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #elif defined(__gfx908__) || defined(__gfx90a__) static_for<0, 8, 1>{}([&](auto k) { float a_f32 = @@ -1650,7 +1650,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8 { #if defined(__gfx94__) or defined(__gfx95__) c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; @@ -1709,7 +1709,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8 { #if defined(__gfx95__) c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; @@ -1767,8 +1767,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8 else { #if defined(__gfx95__) - c_vec = - __builtin_amdgcn_mfma_i32_32x32x32_i8(a_vec, bit_cast(b_vec), c_vec, 0, 0, 0); + c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8( + a_vec, bit_cast(b_vec), c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index b6ada83532..4e5d102e35 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -16,8 +16,9 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false, + WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single> struct WarpGemmMfmaDispatcher; // clang-format off @@ -25,12 +26,20 @@ struct WarpGemmMfmaDispatcher; // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaF16F16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = 
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaF16F16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M4N64K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M64N4K16; }; @@ -46,12 +55,20 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; }; @@ -80,10 +97,18 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; +template<> struct 
WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; // int8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity @@ -102,8 +127,9 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false, + WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single> using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher::Type; + UseStructuredSparsity, + AttrNumAccess>::Type; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index f9d50ed35e..38fd0d408b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -11,9 +11,10 @@ struct WarpGemmImpl { using WarpGemmAttribute = remove_cvref_t; - static constexpr index_t kM = WarpGemmAttribute::kM; - static constexpr index_t kN = WarpGemmAttribute::kN; - static constexpr index_t kK = WarpGemmAttribute::kK; + static constexpr index_t kM = WarpGemmAttribute::kM; + static constexpr index_t kN = WarpGemmAttribute::kN; + static constexpr index_t kK = WarpGemmAttribute::kK; + static constexpr index_t kCMLane = WarpGemmAttribute::kCMLane; /// @brief The number of elements in K dimension processed by single thread in wavefront. /// /// @note Note that WarpGemm may run MFMA instruction multiple times (on different K). diff --git a/include/ck_tile/ops/gemm_group_quant.hpp b/include/ck_tile/ops/gemm_group_quant.hpp new file mode 100644 index 0000000000..9f7565fefb --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant.hpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" +#include "ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp new file mode 100644 index 0000000000..4c136e78f7 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -0,0 +1,482 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/elementwise.hpp" + +namespace ck_tile { + +template +struct BlockGemmQuantBase +{ + using AQDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + + static constexpr index_t UnaryOpSize = UnaryOpSize_; + template + CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale) + { + float scale_reg_f = 0.f; + if constexpr(std::is_same_v) + { + scale_reg_f = + ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + } + else if constexpr(std::is_same_v) + { + scale_reg_f = + ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + } + else if constexpr(std::is_same_v) + { + scale_reg_f = ck_tile::bit_cast(scale); + } + else + { + static_assert(false, "AQDataType must be float, fp8_t or bf8_t."); + } + return scale_reg_f; + } + + template + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, + const WarpWindow& warp_window) + { + const element_wise::PassThroughPack8 elementwise_op{}; + + static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto in_dstr_tensors = load_tile(warp_window); + + using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), + in_dstr_tensors.get_thread_buffer().template get_as()[i]); + }); + } +}; + +// A is block window on shared memory +// AQ (scale tensor) is block distributed tensor. +// Consecutive kQuantGroupSize elements of A are quantized with a separate scale. 
+// B is block window on shared memory
+// C is block distributed tensor
+template
+struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
+{
+    private:
+    template
+    struct GemmTraits_
+    {
+        using Problem         = remove_cvref_t;
+        using Policy          = remove_cvref_t;
+        using ADataType       = remove_cvref_t;
+        using AQDataType      = remove_cvref_t;
+        using BDataType       = remove_cvref_t;
+        using ComputeDataType = remove_cvref_t;
+        using CDataType       = remove_cvref_t;
+        using BlockGemmShape  = remove_cvref_t;
+
+        static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
+        static constexpr index_t kBlockSize      = Problem::kBlockSize;
+        static constexpr auto Scheduler          = Problem::Scheduler;
+
+        // Threadblock GEMM tile size
+        static constexpr index_t MPerBlock  = BlockGemmShape::kM;
+        static constexpr index_t NPerBlock  = BlockGemmShape::kN;
+        static constexpr index_t KPerBlock  = BlockGemmShape::kK;
+        static constexpr index_t AQPerBlock = KPerBlock / kQuantGroupSize;
+
+        static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp();
+        using WarpGemm = remove_cvref_t())>;
+
+        // number of warps along M and N for threadblock's GEMM problem size
+        static constexpr index_t MWarp = config.template at<1>();
+        static constexpr index_t NWarp = config.template at<2>();
+
+        using I0 = number<0>;
+        using I1 = number<1>;
+
+        static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
+                      "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
+        static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
+                      "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
+        static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
+                      "Error! WarpGemm's M is not consistent with BlockGemmShape!");
+        static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
+                      "Error! WarpGemm's N is not consistent with BlockGemmShape!");
+
+        static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
+        static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
+        static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
+
+        static constexpr index_t QScalesPerBlockRow =
+            (KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
+        static constexpr index_t QScalesPerWarpGemmRow =
+            (WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
+
+        static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
+
+        static_assert(kQuantGroupSize % WarpGemm::kK == 0,
+                      "Error! kQuantGroupSize should be a multiple of WarpGemm::kK");
+        static_assert(QScalesPerWarpGemmRow == 1,
+                      "Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
+        static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
+                      "Error! KIterPerWarp should be a multiple of QScalesPerBlockRow");
+
+        static_assert(KPerBlock / kQuantGroupSize > 0,
+                      "Error! Each row of blockgemm should have a separate scale");
+
+        static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
+                      "Error! Warps should cover all Block tile!");
+        static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
+                      "Error! Warps should cover all Block tile!");
+
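To make the group bookkeeping above concrete, one worked instantiation (illustrative numbers, not defaults shipped by this patch):

// Suppose KPerBlock = 128, kQuantGroupSize = 64, WarpGemm::kK = 32. Then:
//   AQPerBlock            = 128 / 64        = 2  scales per block row
//   KIterPerWarp          = 128 / 32        = 4  warp-gemm steps along K
//   QScalesPerBlockRow    = (128 + 63) / 64 = 2
//   QScalesPerWarpGemmRow = (32 + 63) / 64  = 1  (a warp-gemm never spans two groups)
//   KIterPerQScale        = 4 / 2           = 2  warp-gemm steps share each scale
// and the static_asserts hold: 64 % 32 == 0 and 4 % 2 == 0.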
+        // Currently tested combinations (A, AQ, B):
+        // 1. fp8, fp32,       fp8 -> f32
+        // 2. bf8, fp32,       bf8 -> f32
+        // 3. i4,  (fp8/fp32), fp8 -> f32
+        // 4. i4,  (fp8/fp32), bf8 -> f32
+        static_assert((std::is_same_v || std::is_same_v ||
+                       std::is_same_v) &&
+                      (std::is_same_v || std::is_same_v) &&
+                      (std::is_same_v ||
+                       std::is_same_v ||
+                       std::is_same_v) &&
+                      (std::is_same_v ||
+                       std::is_same_v) &&
+                      std::is_same_v);
+
+        static constexpr index_t InterWaveSchedulingMacClusters = 1;
+
+        static constexpr index_t KPack      = WarpGemm::kKPerThread;
+        static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
+    };
+
+    public:
+    using Traits = GemmTraits_;
+
+    using ADataType       = remove_cvref_t;
+    using AQDataType      = remove_cvref_t;
+    using BDataType       = remove_cvref_t;
+    using ComputeDataType = remove_cvref_t;
+    using CDataType       = remove_cvref_t;
+
+    using Base = BlockGemmQuantBase;
+
+    using WarpGemm = remove_cvref_t;
+
+    static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
+    static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
+    static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
+
+    static constexpr index_t MWarp = Traits::MWarp;
+    static constexpr index_t NWarp = Traits::NWarp;
+
+    static constexpr auto Scheduler = Traits::Scheduler;
+    static constexpr uint8_t kA_cvt_scale = std::is_same_v ? 16 : 1;
+    static constexpr uint8_t kB_cvt_scale = std::is_same_v ? 16 : 1;
+
+    using AWarpDstr = typename WarpGemm::AWarpDstr;
+    using BWarpDstr = typename WarpGemm::BWarpDstr;
+    using CWarpDstr = typename WarpGemm::CWarpDstr;
+
+    using AWarpTensor = typename WarpGemm::AWarpTensor;
+    using BWarpTensor = typename WarpGemm::BWarpTensor;
+    using CWarpTensor = typename WarpGemm::CWarpTensor;
+
+    static_assert(std::is_same_v);
+
+    static constexpr auto a_warp_y_lengths =
+        to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
+    static constexpr auto b_warp_y_lengths =
+        to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
+    static constexpr auto c_warp_y_lengths =
+        to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
+
+    static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{};
+    static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{};
+    static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{};
+
+    static constexpr index_t APackedSize =
+        ck_tile::numeric_traits>::PackedSize;
+    static constexpr index_t BPackedSize =
+        ck_tile::numeric_traits>::PackedSize;
+
+    using I0 = number<0>;
+    using I1 = number<1>;
+
+    CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
+    {
+        constexpr index_t KPerThread     = Traits::KPerThread;
+        constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
+
+        constexpr index_t KPerInnerLoop =
+            ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
+
+        constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
+
+        using KIterSeq = std::conditional_t,
+                                            sequence>;
+
+        constexpr auto a_block_outer_dstr_encoding =
+            tile_distribution_encoding,
+                                       tuple, KIterSeq>,
+                                       tuple>,
+                                       tuple>,
+                                       sequence<1, 2>,
+                                       sequence<0, 0>>{};
+        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
+
+        return a_block_dstr_encode;
+    }
+
+    CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
+    {
+        constexpr index_t KPerThread     = Traits::KPerThread;
+        constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
+        constexpr index_t KPerInnerLoop =
+            ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
+        constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
+
+        using KIterSeq = std::conditional_t,
+                                            sequence>;
+
+        constexpr auto b_block_outer_dstr_encoding =
+            tile_distribution_encoding,
+                                       tuple, KIterSeq>,
+                                       tuple>,
+                                       tuple>,
+                                       sequence<1, 2>,
+                                       sequence<0, 0>>{};
+        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
+
+        return b_block_dstr_encode;
+    }
+
+    private:
+    template
+    struct BlockGemmImpl
+    {
+    };
+
+    template
+    struct BlockGemmImpl
+    {
+        static constexpr auto ALdsTileDistr =
+            decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
+        static constexpr auto BLdsTileDistr =
+            decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
+
+        using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr));
+        using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr));
+
+        ALdsTile a_warp_tile_;
+        BLdsTile b_warp_tile_;
+
+        template
+        CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
+                                          const BSmemBlockWindow& b_block_window)
+        {
+            if constexpr(std::is_same_v)
+            {
+                static_assert(std::is_same_v ||
+                              std::is_same_v);
+                Base::load_interleaved_pk_type(a_warp_tile_, a_block_window);
+            }
+            else
+            {
+                load_tile(a_warp_tile_, a_block_window);
+            }
+            if constexpr(std::is_same_v)
+            {
+                static_assert(std::is_same_v ||
+                              std::is_same_v);
+                Base::load_interleaved_pk_type(b_warp_tile_, b_block_window);
+            }
+            else
+            {
+                load_tile(b_warp_tile_, b_block_window);
+            }
+        }
+
+        // C += A * B
+        template
+        CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
+                                       AQBlockTensor& aq_block_tensor,
+                                       [[maybe_unused]] ASmemBlockWindow& a_block_window,
+                                       [[maybe_unused]] BSmemBlockWindow& b_block_window)
+        {
+            static_assert(std::is_same_v,
+                          "The CDataType as defined in traits should be the same as the "
+                          "corresponding C block tensor data type!");
+
+            // hot loop:
+            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    CWarpTensor c_warp_tensor;
+
+                    static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
+                        static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
+                            constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
+
+                            AWarpTensor a_warp_tensor;
+                            a_warp_tensor.get_thread_buffer() =
+                                a_warp_tile_.get_y_sliced_thread_data(
+                                    merge_sequences(sequence{}, a_warp_y_index_zeros),
+                                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
+
+                            BWarpTensor b_warp_tensor;
+                            b_warp_tensor.get_thread_buffer() =
+                                b_warp_tile_.get_y_sliced_thread_data(
+                                    merge_sequences(sequence{}, b_warp_y_index_zeros),
+                                    merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
+
+                            if constexpr(kIterInQScale == 0)
+                            {
+                                c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
+                            }
+                            else
+                            {
+                                WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
+                            }
+                        });
+
+                        // Need to multiply aquant with accumulated C
+                        //
+                        // The accumulated C tile has the standard distribution. For example
+                        // lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
+                        // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
+                        // [26,0], [27,0].
+                        //
+                        // These elements are in different rows, so we need to get the scale
+                        // value for the corresponding row.
+                        // Based on aquant's tile distribution, it can be inferred which
+                        // lane holds the relevant scale.
For example, the scales corresponding + // to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9, + // 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively. + // + // These scales can be obtained using __builtin_amdgcn_ds_bpermute. + + // MIters per warp + constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM; + + // Reg block offset based on mIter + constexpr index_t reg_block_offset = + ((mIter / mIters_per_warp) * Traits::AQPerBlock); + + constexpr index_t lane_base_offset = + (mIter % mIters_per_warp) * WarpGemm::kM; + + // Scale tensor offset along K + constexpr index_t src_reg_offset = reg_block_offset + kQScale; + + constexpr uint32_t kTileRows = 4; + constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; + + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) { + // Multiply by 4 because output is stored in tiles of 4 + // x CNLane + constexpr uint32_t row_base = + ((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) + + ((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane); + + constexpr uint32_t reg_offset_for_row_data = c_row / WarpGemm::kCMLane; + + // Lane index to source scale from + uint32_t src_lane_idx = lane_base_offset + row_base + + (__lane_id() / WarpGemm::kN * kTileRows); + + // Directly index into thread buffer corresponding to + // desired row coefficient + auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset]; + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + // Pull scale data across lanes + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); + + c_block_tensor + .get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] += + (c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] * + scale_reg_f * kA_cvt_scale * kB_cvt_scale); + }); + }); + }); + }); + } + }; + + public: + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + AQBlockTensor& aq_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_(c_block_tensor, aq_block_tensor, a_block_window, b_block_window); + } + + private: + BlockGemmImpl block_gemm_impl_{}; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp new file mode 100644 index 0000000000..b1f89fe2e2 --- /dev/null +++ 
b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp
@@ -0,0 +1,679 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common.hpp"
+#include "ck_tile/host/concat.hpp"
+
+namespace ck_tile {
+
+struct AQuantGemmProblem
+{
+    CK_TILE_HOST AQuantGemmProblem() = default;
+    CK_TILE_HOST AQuantGemmProblem(index_t M_,
+                                   index_t N_,
+                                   index_t K_,
+                                   index_t QK_,
+                                   index_t stride_A_,
+                                   index_t stride_B_,
+                                   index_t stride_C_,
+                                   index_t stride_AQ_)
+        : M(M_),
+          N(N_),
+          K(K_),
+          QK(QK_),
+          stride_A(stride_A_),
+          stride_B(stride_B_),
+          stride_C(stride_C_),
+          stride_AQ(stride_AQ_)
+    {
+    }
+
+    index_t M;
+    index_t N;
+    index_t K;
+    index_t QK;
+    index_t stride_A;
+    index_t stride_B;
+    index_t stride_C;
+    index_t stride_AQ;
+};
+
+struct AQuantGemmHostArgs : public AQuantGemmProblem
+{
+    CK_TILE_HOST AQuantGemmHostArgs() = default;
+    CK_TILE_HOST AQuantGemmHostArgs(const void* a_ptr_,
+                                    const void* b_ptr_,
+                                    void* c_ptr_,
+                                    const void* aq_ptr_,
+                                    index_t k_batch_,
+                                    index_t M_,
+                                    index_t N_,
+                                    index_t K_,
+                                    index_t QK_,
+                                    index_t stride_A_,
+                                    index_t stride_B_,
+                                    index_t stride_C_,
+                                    index_t stride_AQ_)
+        : AQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_AQ_),
+          a_ptr(a_ptr_),
+          b_ptr(b_ptr_),
+          aq_ptr(aq_ptr_),
+          c_ptr(c_ptr_),
+          k_batch(k_batch_)
+    {
+    }
+
+    const void* a_ptr;
+    const void* b_ptr;
+    const void* aq_ptr;
+    void* c_ptr;
+    index_t k_batch;
+};
+
+struct AQuantGemmKernelArgs
+{
+    const void* a_ptr;
+    const void* b_ptr;
+    const void* aq_ptr;
+    void* c_ptr;
+    index_t M;
+    index_t N;
+    index_t K;
+    index_t QK;
+    index_t stride_A;
+    index_t stride_B;
+    index_t stride_C;
+    index_t stride_AQ;
+    index_t k_batch;
+};
+
+template
+struct AQuantGemmKernel
+{
+    using TilePartitioner  = remove_cvref_t;
+    using GemmPipeline     = remove_cvref_t;
+    using EpiloguePipeline = remove_cvref_t;
+    using ALayout          = remove_cvref_t;
+    using AQLayout         = remove_cvref_t;
+    using BLayout          = remove_cvref_t;
+    using CLayout          = remove_cvref_t;
+    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
+
+    using ADataType  = remove_cvref_t;
+    using AQDataType = remove_cvref_t;
+    using BDataType  = remove_cvref_t;
+    using CDataType  = remove_cvref_t;
+
+    static constexpr auto I0 = number<0>();
+    static constexpr auto I1 = number<1>();
+    static constexpr auto I2 = number<2>();
+    static constexpr auto I3 = number<3>();
+
+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        return concat('_', "gemm", gemm_prec_str, GemmPipeline::GetName());
+        // clang-format on
+    }
+
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
+    {
+        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
+    }
+
+    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
+
+    CK_TILE_HOST static constexpr AQuantGemmKernelArgs
+    MakeKernelArgs(const AQuantGemmHostArgs& hostArgs)
+    {
+        return AQuantGemmKernelArgs{hostArgs.a_ptr,
+                                    hostArgs.b_ptr,
+                                    hostArgs.aq_ptr,
+                                    hostArgs.c_ptr,
+                                    hostArgs.M,
+                                    hostArgs.N,
+                                    hostArgs.K,
+                                    hostArgs.QK,
+                                    hostArgs.stride_A,
+                                    hostArgs.stride_B,
+                                    hostArgs.stride_C,
+                                    hostArgs.stride_AQ,
+                                    hostArgs.k_batch};
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
+    }
+
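The SplitKBatchOffset helper that follows rounds the per-batch K read up to a multiple of the warp-tile K. A worked pass through its arithmetic with illustrative sizes (K = 1000, k_batch = 4, K1 = 32; these are not values from the patch):

// K_t   = k_batch * K1             = 128
// KRead = (K + K_t - 1) / K_t * K1 = (1127 / 128) * 32 = 8 * 32 = 256
// k_id 0..2: splitted_k = KRead = 256
// k_id 3   : splitted_k = K - KRead * (k_batch - 1) = 1000 - 768 = 232
// The A/B offsets advance by k_id * KRead, additionally scaled by the tensor's
// stride when the split runs along the strided dimension of its layout.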
AQuantGemmKernelArgs& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); + + if constexpr(std::is_same_v) + { + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + else if constexpr(std::is_same_v) + { + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A); + } + + if constexpr(std::is_same_v) + { + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B); + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = __builtin_amdgcn_readfirstlane(KRead); + } + else + { + splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const AQuantGemmKernelArgs& kargs) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + + static_assert(std::is_same_v); + if(kargs.QK % GemmPipeline::GetVectorSizeAQ() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + return false; + } + + if constexpr(std::is_same_v) + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + return false; + } + if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + } + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } + return false; + } + } + else + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + return false; + } + if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) + { + 
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + } + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + return true; + } + + template + CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_ptr, + const AQDataType* aq_ptr, + CDataType* c_ptr, + const AQuantGemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + const auto& aq_tensor_view = [&]() { + static_assert(std::is_same_v); + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.QK), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + }(); + + const auto& b_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + 
number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + }(); + + // TODO: enable vector write for C in ColMajor + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& aq_pad_view = [&]() { + const auto& aq_tensor_view = views.at(I1); + static_assert(std::is_same_v); + return pad_tensor_view( + aq_tensor_view, + make_tuple(number{}, + number{}), + // TODO: Add support for padding. + sequence{}); + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I2); + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // TODO vector write in for C in ColMajor + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I3); + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& aq_pad_view = views.at(I1); + const auto& b_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + const auto& aq_block_window = [&]() { + static_assert(std::is_same_v); + return make_tile_window( + aq_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + }(); + + const auto& b_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {0, i_n}); + } + }(); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, 
number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, aq_block_window, b_block_window, c_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param aq_ptr input AQ pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + * @tparam DstInMemOp Destination memory operation (default: set). + */ + template + CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, + const BDataType* b_ptr, + const AQDataType* aq_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const AQuantGemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( + a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& aq_block_window = gemm_tile_windows.at(I1); + const auto& b_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + + CK_TILE_DEVICE void operator()(AQuantGemmKernelArgs kargs) const + { + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + // options + const ADataType* a_ptr = static_cast(kargs.a_ptr); + const BDataType* b_ptr = static_cast(kargs.b_ptr); + const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); + CDataType* c_ptr = static_cast(kargs.c_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + assert(kargs.k_batch == 1); + RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp new file mode 100644 index 0000000000..1356d7e222 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
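The kernel pieces above compose on the host in a fixed order: build host args, lower them to kernel args, validate, then size the grid. Below is a minimal, illustrative host-side sketch (not taken from this patch); the launcher itself is framework-specific and elided, and the `QK = K / quant_group_size` relationship is an assumption for illustration.

#include <hip/hip_runtime.h>
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp"

// `Kernel` is assumed to be a fully specialized ck_tile::AQuantGemmKernel.
template <typename Kernel>
bool prepare_aquant_gemm(const void* a, const void* b, const void* aq, void* c,
                         ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                         ck_tile::index_t quant_group_size,
                         ck_tile::index_t stride_A, ck_tile::index_t stride_B,
                         ck_tile::index_t stride_C, ck_tile::index_t stride_AQ)
{
    // One scale per quant group along K, per row of A (assumption for this sketch).
    const ck_tile::index_t QK = K / quant_group_size;

    const ck_tile::AQuantGemmHostArgs host_args(
        a, b, c, aq, /*k_batch=*/1, M, N, K, QK, stride_A, stride_B, stride_C, stride_AQ);

    const auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArgument(kargs))
        return false; // e.g. k_batch != 1, or a vector-size divisibility check failed

    const dim3 grids  = Kernel::GridSize(kargs.M, kargs.N, kargs.k_batch);
    const dim3 blocks = Kernel::BlockSize();
    (void)grids;
    (void)blocks; // hand Kernel{}, grids and blocks to the host framework's launcher
    return true;
}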
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
+
+namespace ck_tile {
+
+template <typename Problem, typename Policy>
+struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
+{
+    using Base           = GemmPipelineAgBgCrImplBase<Problem, Policy>;
+    using ADataType      = typename Base::ADataType;
+    using ALayout        = typename Base::ALayout;
+    using BDataType      = typename Base::BDataType;
+    using BLayout        = typename Base::BLayout;
+    using BlockGemmShape = typename Base::BlockGemmShape;
+
+    using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
+
+    static constexpr index_t MPerBlock = BlockGemmShape::kM;
+    static constexpr index_t NPerBlock = BlockGemmShape::kN;
+    static constexpr index_t KPerBlock = BlockGemmShape::kK;
+
+    static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
+    static constexpr index_t KPerBlockAQ    = KPerBlock / QuantGroupSize;
+
+    static_assert(KPerBlock % QuantGroupSize == 0,
+                  "KPerBlock must be a multiple of QuantGroupSize");
+
+    // Create DRAM tile window for AQ
+    template <typename AQDramBlockWindowTmp>
+    CK_TILE_DEVICE constexpr auto
+    GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const
+    {
+        static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
+
+        using YPerTile = number<MPerBlock>;
+        using XPerTile = number<KPerBlockAQ>;
+
+        auto aq_copy_dram_window =
+            make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
+                             make_tuple(YPerTile(), XPerTile()),
+                             aq_dram_block_window_tmp.get_window_origin(),
+                             Policy::template MakeAQDramTileDistribution<Problem>());
+        return aq_copy_dram_window;
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp
new file mode 100644
index 0000000000..2004f7d90e
--- /dev/null
+++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp
@@ -0,0 +1,93 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
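GetAQDramLoadWindow above sizes the scale window purely from the block tile and the quant group size; a minimal constexpr sketch with assumed numbers (128x64 block tile and group size 32 are illustrative, not values from this patch):

// Assumed illustrative numbers; only the relationships come from the base class above.
constexpr int MPerBlock      = 128;
constexpr int KPerBlock      = 64;
constexpr int QuantGroupSize = 32;
static_assert(KPerBlock % QuantGroupSize == 0,
              "KPerBlock must be a multiple of QuantGroupSize");

// One scale per (row, quant group): the AQ window is MPerBlock x KPerBlockAQ,
// while the matching A window is MPerBlock x KPerBlock.
constexpr int KPerBlockAQ = KPerBlock / QuantGroupSize; // = 2 scales per row per block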
+
+#pragma once
+
+#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
+#include "gemm_group_quant_utils.hpp"
+
+namespace ck_tile {
+
+struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy
+{
+    using Base = UniversalGemmPipelineAgBgCrPolicy;
+    using Base::I0;
+    using Base::I1;
+    using Base::I2;
+
+    using Base::ATileAccessPattern;
+    using Base::BTileAccessPattern;
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ()
+    {
+        using AQLayout   = remove_cvref_t<typename Problem::AQLayout>;
+        using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
+        constexpr index_t MPerBlock   = Problem::BlockGemmShape::kM;
+        constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
+        constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
+
+        static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
+        return GetAQGlobalVectorLoadSize();
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution()
+    {
+        using AQLayout       = remove_cvref_t<typename Problem::AQLayout>;
+        using BlockGemmShape = typename Problem::BlockGemmShape;
+
+        constexpr index_t BlockSize   = Problem::kBlockSize;
+        constexpr index_t MPerBlock   = Problem::BlockGemmShape::kM;
+        constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
+        constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
+        constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
+        using WarpTile = typename Problem::BlockGemmShape::WarpTile;
+        using WarpGemm = WarpGemmMfmaDispatcher;
+
+        static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
+        using TileEncodingPattern = TileDistributionEncodingPatternAQ;
+
+        return TileEncodingPattern::Make2DStaticTileDistribution();
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
+    {
+        using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
+        using WarpTile   = typename Problem::BlockGemmShape::WarpTile;
+
+        static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
+                      "kQuantGroupSize must be a multiple of KPerWarpGemm!");
+
+        using WarpGemm = WarpGemmMfmaDispatcher;
+        static_assert(std::is_same_v ||
+                      std::is_same_v);
+        static_assert(std::is_same_v);
+        using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy;
+        return AQuantBlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp
new file mode 100644
index 0000000000..746396b13a
--- /dev/null
+++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp
@@ -0,0 +1,475 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
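Two divisibility rules govern the shapes above: the pipeline base requires KPerBlock to cover a whole number of quant groups, and the policy requires each warp-level GEMM K-slice to stay inside a single group, so one gathered scale covers an entire MFMA K-slice. A sketch with assumed numbers (128/32/64 are illustrative, not taken from this patch):

constexpr int KPerBlock      = 128; // block-tile K (assumed)
constexpr int KPerWarpGemm   = 32;  // WarpTile::at(I2) (assumed)
constexpr int QuantGroupSize = 64;  // kQuantGroupSize (assumed)

// Base-class rule: a block tile spans a whole number of quant groups.
static_assert(KPerBlock % QuantGroupSize == 0);    // 2 groups per block tile
// Policy rule: one scale applies to the whole K-slice of a warp-level GEMM.
static_assert(QuantGroupSize % KPerWarpGemm == 0); // 2 warp-gemms per group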
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BaseAQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +{ + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + if(has_hot_loop) + { + if(tail_number == ck_tile::TailNumber::Full) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + else + { + if(tail_number == ck_tile::TailNumber::Full) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + } +}; + +template +struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV3 +{ + using Base = BaseGemmPipelineAgBgCrCompV3; + using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t AQPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using ALayout = remove_cvref_t; + using AQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize; + static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetVectorSizeAQ() + { + return Policy::template GetVectorSizeAQ(); + } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } 
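+    // Editorial note: TailHandler, defined in the base struct above, is a
+    // runtime-to-compile-time dispatcher. It inspects (has_hot_loop, tail_number)
+    // once and re-invokes run_func with ck_tile::bool_constant and
+    // ck_tile::integral_constant<TailNumber, ...> arguments, so the callee can
+    // read them via `if constexpr` and a separate, branch-free hot loop is
+    // instantiated per combination. Only Full/Odd/Even tails are dispatched;
+    // anything else throws, which matches the two global-prefetch stages of
+    // this pipeline (see the stage counts in the header comment above).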
+ static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using Base::PrefetchStages; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "aquant_pipeline_AgBgCrCompV3", + concat('x', MPerBlock, NPerBlock, KPerBlock), + BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK), + concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + constexpr index_t AQ_Buffer_Load_Inst_Num = + MPerBlock * KPerBlockAQ / (BlockSize * GetVectorSizeAQ()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", " + << "AQ vector size: " << GetVectorSizeAQ() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << ", " << "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "QuantGroupSize: " << QuantGroupSize << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + + template + 
struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/AQ Dram block window should have the same data type as appropriate " + "([A|B|AQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_aq_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); + static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + "Aq block window has incorrect lengths for defined AqLayout!"); + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex; + + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + using AQBlockTile = + decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + + auto block_gemm = BlockGemm(); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + AQBlockTile aq_block_tile[2]; + int currIdx = 0; + + auto c_block_tile = block_gemm.MakeCBlockTile(); + + constexpr 
ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr AQDramTileWindowStep aq_dram_tile_window_step = + is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); + + // DRAM prefetch (global read 0) + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch( + aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffled2DStaticTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffled2DStaticTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + __builtin_amdgcn_sched_barrier(0); + + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], + aq_copy_dram_window, + aq_dram_tile_window_step); + + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + + currIdx = (currIdx + 1) % 2; + + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + { + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + } + else + { + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], + aq_copy_dram_window, + aq_dram_tile_window_step); + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + 
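+                // Note on the aq_block_tile[2] ping-pong used above: the scale
+                // tile for the next iteration is prefetched into
+                // aq_block_tile[(currIdx + 1) % 2] while the current iteration
+                // consumes aq_block_tile[currIdx]; flipping currIdx after each
+                // block_gemm call keeps the prefetch target and the buffer being
+                // consumed disjoint. This Even-tail path issues one extra
+                // prefetch/compute pair compared to Full/Odd, which is why it
+                // flips currIdx once more.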
block_sync_lds(); + + currIdx = (currIdx + 1) % 2; + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + } + return c_block_tile; + } + }; + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp new file mode 100644 index 0000000000..4cca30fd3b --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
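The convenience operator() above binds identity element functions before delegating to PipelineImpl. A caller wanting fused element-wise pre-processing of A or B on the LDS-prefill path (where the element functions are applied via LocalPrefill) can substitute its own callables; an illustrative sketch, with the commented call spelling assumed rather than taken from this patch:

// Identity element functions, equivalent to what operator() above forwards.
const auto pass_a = [](const ADataType& a) { return a; };
const auto pass_b = [](const BDataType& b) { return b; };
// PipelineImpl<Scheduler>{}(a_window, pass_a, b_window, pass_b,
//                           aq_window, num_loop, p_smem);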
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" + +#include + +namespace ck_tile { + +template +struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase +{ + using Base = GemmPipelineProblemBase; + + using Traits = typename Base::Traits; + + using typename Base::ADataType; + using typename Base::BDataType; + using typename Base::CDataType; + using typename Base::ComputeDataType; + using AQDataType = remove_cvref_t; + + using BlockGemmShape = typename Base::BlockGemmShape; + + using typename Base::ALayout; + using typename Base::BLayout; + using typename Base::CLayout; + + static constexpr bool TransposeC = false; + + using Base::kBlockSize; + + using Base::kPadK; + using Base::kPadM; + using Base::kPadN; + + using Base::DoubleSmemBuffer; + using Base::VectorLoadSize; + + using AQLayout = remove_cvref_t; + + static constexpr uint32_t kQuantGroupSize = QuantGroupSize_; + static constexpr auto Scheduler = Scheduler_; + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; + + static_assert(BlockGemmShape::kK % kQuantGroupSize == 0); + static_assert(Scheduler == GemmPipelineScheduler::Intrawave); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm_aquant_problem", + concat('x', VectorLoadSize, kBlockSize), + concat('x', kPadM, kPadN, kPadK), + Scheduler, + "QuantGroupSize", + kQuantGroupSize); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAQ() + { + static_assert(std::is_same_v); + return VectorLoadSize / sizeof(AQDataType); + } + + static constexpr index_t VectorSizeAQ = []() { + static_assert(std::is_same_v); + return kPadK ? 1 : GetAlignmentAQ(); + }(); +}; + +template +using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp new file mode 100644 index 0000000000..c018314ab7 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
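The GetAlignmentAQ()/VectorSizeAQ definitions above reduce to simple byte arithmetic; a constexpr sketch with assumed values (a 16-byte VectorLoadSize and 4-byte scales are illustrative, not values from this patch):

constexpr int  VectorLoadSize = 16;                             // bytes per global load (assumed)
constexpr int  AlignmentAQ    = VectorLoadSize / sizeof(float); // 4 scales per load
constexpr bool kPadK          = false;
constexpr int  VectorSizeAQ   = kPadK ? 1 : AlignmentAQ;        // K padding forces scalar loads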
+ +#pragma once + +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST_DEVICE static constexpr auto GetAQGlobalVectorLoadSize() +{ + using I1 = number<1>; + constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t BlockSize = Problem::kBlockSize; + + // Data is replicated across warps along NWarps, so we divide BlockSize by NWarps + constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps); + constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; + + // Define vector load candidates in descending order of priority + constexpr std::array candidates{ + PackedSize * 32 / sizeof(DataType), + PackedSize * 16 / sizeof(DataType), + PackedSize * 8 / sizeof(DataType), + PackedSize * 4 / sizeof(DataType), + PackedSize * 2 / sizeof(DataType), + }; + + for(const auto vec_size : candidates) + { + if(vec_size <= 0 || XPerTile % vec_size != 0 || elements_per_thread % vec_size != 0) + continue; + bool is_valid = (vec_size > 0) && (XPerTile % vec_size == 0) && + (elements_per_thread % vec_size == 0) && vec_size != candidates[4]; + if(is_valid) + { + return vec_size; + } + } + return PackedSize; // Absolute fallback +} + +// AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across +// threads. Post mfma scales are shuffled across threads in the warp and applied to +// accum registers. +template +struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern +{ + // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! + static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); + static constexpr index_t warp_size = get_warp_size(); + static constexpr index_t num_warps = BlockSize / get_warp_size(); + + static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{}); + static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{}); + static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{}); + + static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM); + + static_assert(num_warps == MWarps * NWarps * KWarps); + + // KWarps > 1 isn't supported + static_assert(KWarps == 1); + + // # of elements per thread + static constexpr index_t X = XPerTile; + + static constexpr index_t Y0 = 1; + static constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1; + static constexpr index_t Y2 = MWarps; + static constexpr index_t Y3 = WarpGemm::kM; + static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp."); + static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile, + "Y0, Y1, Y2, Y3 must cover the blocktile along Y."); + + CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 1>>, + tuple, sequence<0, 3>>, + sequence<1, 2>, + sequence<1, 0>>{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp new file mode 100644 index 0000000000..4972badb3f --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
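The candidate scan in GetAQGlobalVectorLoadSize above is easiest to follow with concrete numbers; the values below are assumptions for illustration, not values from this patch.

// Assume float scales (PackedSize = 1, 4 bytes), YPerTile = 128, XPerTile = 4,
// BlockSize = 256, NWarps = 2:
//   elements_per_thread = (128 * 4) / (256 / 2) = 4
//   candidates          = {8, 4, 2, 1, 0}  // 32/16/8/4/2 bytes over a 4-byte type
// vec_size = 8 is rejected (XPerTile % 8 != 0); vec_size = 4 divides both
// XPerTile and elements_per_thread and is not the last candidate, so 4 is
// returned, i.e. one 16-byte (dwordx4) load of scales per thread.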
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+
+template <bool kPadM_,
+          bool kPadN_,
+          bool kPadK_,
+          typename ALayout_,
+          typename AQLayout_,
+          typename BLayout_,
+          typename CLayout_>
+struct TileGemmAQuantTraits
+{
+    static constexpr bool kPadM = kPadM_;
+    static constexpr bool kPadN = kPadN_;
+    static constexpr bool kPadK = kPadK_;
+
+    static constexpr int _VectorSize = 16;
+
+    using ALayout  = ALayout_;
+    using BLayout  = BLayout_;
+    using CLayout  = CLayout_;
+    using AQLayout = AQLayout_;
+
+    static constexpr bool UseStructuredSparsity = false;
+    static constexpr index_t NumWaveGroups      = 1;
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp
index ae5720776c..29332f941a 100644
--- a/include/ck_tile/ops/grouped_convolution.hpp
+++ b/include/ck_tile/ops/grouped_convolution.hpp
@@ -3,9 +3,11 @@
 #pragma once
 
+#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp"
 #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp"
 #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
 #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
+#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
 #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
new file mode 100644
index 0000000000..115f6dea19
--- /dev/null
+++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
@@ -0,0 +1,862 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common.hpp"
+#include "ck_tile/host/concat.hpp"
+#include "ck_tile/core/utility/env.hpp"
+#include "ck_tile/host/convolution_parameter.hpp"
+#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
+#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
+
+namespace ck_tile {
+
+/// @brief The Grouped Convolution kernel device arguments.
+template +struct GroupedConvBwdWeightKernelArgs +{ + + using ConvToGemmTransformer = + TransformConvBwdWeightToGemm; + static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0])}; + input_left_pads = {static_cast(args.input_left_pads_[0])}; + input_right_pads = {static_cast(args.input_right_pads_[0])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + // tuple + auto grid_descs = + conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType::NDimSpatial>(); + + a_grid_desc_m_k = grid_descs.at(number<0>{}); + b_grid_desc_n_k = grid_descs.at(number<1>{}); + c_grid_desc_m_n = grid_descs.at(number<2>{}); + + group_stride_a = args.K_; // A: Out NWGK + group_stride_b = args.C_; // B: In NWGC + group_stride_c = args.K_ * args.C_ * // C: Wei GKXC + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = args.G_; + } + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1])}; + conv_filter_dilations = 
{static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + // tuple + auto grid_descs = + conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType::NDimSpatial>(); + + a_grid_desc_m_k = grid_descs.at(number<0>{}); + b_grid_desc_n_k = grid_descs.at(number<1>{}); + c_grid_desc_m_n = grid_descs.at(number<2>{}); + + group_stride_a = args.K_; // A: Out NHWGK + group_stride_b = args.C_; // B: In NHWGC + group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = args.G_; + } + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1]), + static_cast(args.output_spatial_lengths_[2])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1]), + static_cast(args.conv_filter_dilations_[2])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + 
input_right_pads};
+
+        // tuple of A/B/C grid descriptors
+        auto grid_descs =
+            conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
+                GroupedConvTraitsType::NDimSpatial>();
+
+        a_grid_desc_m_k = grid_descs.at(number<0>{});
+        b_grid_desc_n_k = grid_descs.at(number<1>{});
+        c_grid_desc_m_n = grid_descs.at(number<2>{});
+
+        group_stride_a = args.K_; // A: Out NDHWGK
+        group_stride_b = args.C_; // B: In NDHWGC
+        group_stride_c = args.K_ * args.C_ * // C: Wei GKZYXC
+                         std::accumulate(args.filter_spatial_lengths_.begin(),
+                                         args.filter_spatial_lengths_.end(),
+                                         1,
+                                         std::multiplies<index_t>());
+
+        GemmM     = a_grid_desc_m_k.get_length(number<0>{});
+        GemmN     = b_grid_desc_n_k.get_length(number<0>{});
+        GemmK     = a_grid_desc_m_k.get_length(number<1>{});
+        GemmBatch = args.G_;
+    }
+
+    using ABCGridDescs =
+        remove_cvref_t;
+
+    using AGridDescMK = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
+    using BGridDescNK = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
+    using CGridDescMN = remove_cvref_t<decltype(ABCGridDescs{}[number<2>{}])>;
+
+    static constexpr index_t NonSpatialDims = 3;
+    array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> in_g_n_c_wis_lengths;
+    array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> wei_g_k_c_xs_lengths;
+    array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> out_g_n_k_wos_lengths;
+
+    array<index_t, GroupedConvTraitsType::NDimSpatial> conv_filter_strides;
+    array<index_t, GroupedConvTraitsType::NDimSpatial> conv_filter_dilations;
+    array<index_t, GroupedConvTraitsType::NDimSpatial> input_left_pads;
+    array<index_t, GroupedConvTraitsType::NDimSpatial> input_right_pads;
+
+    index_t k_batch;
+    index_t GemmM;
+    index_t GemmN;
+    index_t GemmK;
+    index_t GemmBatch;
+
+    const void* out_ptr;
+    const void* in_ptr;
+    std::array<const void*, NumDTensor> ds_ptr;
+    void* wei_ptr;
+
+    AGridDescMK a_grid_desc_m_k;
+    BGridDescNK b_grid_desc_n_k;
+    CGridDescMN c_grid_desc_m_n;
+
+    long_index_t group_stride_a;
+    long_index_t group_stride_b;
+    long_index_t group_stride_c;
+};
+
+/// @brief The Grouped Convolution Backward Weight kernel template.
+///
+/// @paragraph Overview Overview
+///            This class provides the grouped convolution backward weight kernel template.
+///            By semantically dividing the Implicit GEMM algorithm into the following parts
+///            we achieve a flexible, versatile and robust kernel implementation.
+///
+///            @li @b Prolog - The start of the GEMM kernel implementation in the
+///                @ref operator() "function call operator", which determines the work scope
+///                of each workgroup.
+///            @li @b GemmPipeline - The core part, the @a "heart" of the matrix
+///                multiplication algorithm. This is the place where each workgroup is
+///                loading data from global memory and carrying out dot products.
+///            @li @b Epilogue - The @a "final" part of the matrix multiplication
+///                implementation, responsible for storing results to global memory. This is
+///                also the place where any additional operator fusion may take place.
+///
+///            Additionally, both @ref GemmPipeline_ "GemmPipeline" and
+///            @ref EpiloguePipeline_ "EpiloguePipeline" are parameterized with a so-called
+///            @a Policy, which determines all internal details of those functional parts.
+///            You can think of it as both the gemm and epilogue pipelines providing the
+///            control-flow logic that is steered by the policies. Moreover, the policy is
+///            responsible for the definition of all necessary data layouts and the threads'
+///            work distribution.
+///
+/// @tparam ConvSpecialization Tensor descriptors specialization.
+/// @tparam TilePartitioner_  The type of class providing mapping of workgroup index into
+///                           the output data tile to be calculated. It determines the
+///                           workgroup to data relationship (or in other words - which
+///                           data would be processed and calculated by which workgroup).
+/// @tparam GemmPipeline_     The type of class which provides the core part of matrix
+///                           multiplication. This class should provide implementation of
+///                           data loading from global memory and performing block-wise
+///                           matrix multiplication.
You can think of it as a work done by +/// single workgroup point of view. +/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix +/// multiplication implementation. It is responsible for storing +/// results calculated by @ref GemmPipeline_ "GemmPipeline" to +/// the output C tensor in global memory. +template +struct GroupedConvolutionBackwardWeightKernel +{ + static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial_; + static constexpr ConvolutionSpecialization ConvSpecialization = + GroupedConvTraitsType::ConvSpecialization; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using GemmALayout = remove_cvref_t; + using GemmBLayout = remove_cvref_t; + using GemmCLayout = remove_cvref_t; + + using InLayout = remove_cvref_t; + using WeiLayout = remove_cvref_t; + using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + using GemmDsLayout = remove_cvref_t; + static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + + using InDataType = remove_cvref_t; + using WeiDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + // Below type is actually accumulation data type - the output of block GEMM. + using OutDataType = remove_cvref_t; + + using GroupedConvBwdWeightKernelArgsSpecialized = + GroupedConvBwdWeightKernelArgs; + + // TODO: Enable this + static constexpr bool IsSplitKSupported = true; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>(); + + static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK, + "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "grouped_convolution_backward_weight", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto + GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs) + { + return dim3( + TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized + MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs) + { + return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs); + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = + __builtin_amdgcn_readfirstlane((kargs.GemmK + K_t - 1) / K_t * K1); + + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = __builtin_amdgcn_readfirstlane(KRead); + } + else + { + splitted_k = + 
__builtin_amdgcn_readfirstlane(kargs.GemmK - KRead * (kargs.k_batch - 1)); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const stream_config& s) + { + return [&]() { + if(kargs.k_batch > 1) + hipGetErrorString(hipMemsetAsync(kargs.wei_ptr, + 0, + kargs.GemmBatch * kargs.GemmM * kargs.GemmN * + sizeof(WeiDataType), + s.stream_id_)); + }; + } + + CK_TILE_HOST static bool + IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs) + { + if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) || + !IsSplitKSupported) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for k_batch > 1!"); + } + return false; + } + } + + const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}]; + const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}]; + + // check ConvSpecialization + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t ConvStride = kargs.conv_filter_strides[i]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + // check if it's a 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if(ConvC != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + } + + namespace ctc = tensor_layout::convolution; + + // check vector access of B (input image) + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + // Check access per C + if(ConvC % GemmPipeline::GetVectorSizeB() != 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Unsupported input layout!"); + return false; + } + + // check vector access of C (weight) + // FIXME: layout + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvC % EpiloguePipeline::GetVectorSizeC() != 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector store size for weight!"); + return false; + } + } + else + { + CK_TILE_ERROR("Unsupported weight layout!"); + return false; + } + + // check vector access of A (output image) + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvK % GemmPipeline::GetVectorSizeA() != 0) + { + CK_TILE_ERROR("Conv K is not a multiple of vector load size for output image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Unsupported output layout!"); + return false; + } + + return true; + } + + template + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const OutDataType* a_ptr, + const InDataType* b_ptr, + const 
std::array& ds_ptr, + WeiDataType* c_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); + const auto& a_tensor_view = [&]() { + return make_tensor_view(a_ptr, + kargs.a_grid_desc_m_k); // A: out + }(); + + const auto& b_tensor_view = [&]() { + return make_tensor_view(b_ptr, + kargs.b_grid_desc_n_k); // B: in + }(); + + const auto& c_tensor_view = [&]() { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.GemmM, kargs.GemmN), + make_tuple(kargs.GemmN, 1), + number{}, + number<1>{}); + }(); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + static_assert(std::is_same_v, OutLayout>, + "Not supported!"); + static_assert(std::is_same_v, + "Not supported!"); + static_assert(std::is_same_v, OutDataType>, + "Not supported!"); + + return make_tensor_view( + static_cast(ds_ptr[i]), kargs.c_grid_desc_m_n); + }, + number{}); + + return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{} * k_batch), + sequence{}); + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I1); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{} * k_batch), + sequence{}); + }(); + + const auto& ds_tensor_view = views.at(I2); + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I3); + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, + const index_t i_m, + const index_t i_n, + const index_t i_k) + { + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, i_k}); + }(); + + const auto& b_block_window = [&]() { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, i_k}); + }(); + + const auto ds_block_window = generate_tuple( + [&](auto i) { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + } + + /** + * @brief Runs a single GEMM problem cooperatively by the whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs Grouped Convolution Backward Weight kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
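+ * @param ds_ptr input Ds pointers - additional fused operands consumed by the epilogue + * @param num_loop The number of GEMM K-loop iterations for this split-K slice. + * @param block_idx_k The GEMM's K dimension tile index processed by this workgroup (split-K). 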
+ * + */ + CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + void* smem_ptr_0, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t num_loop, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t block_idx_k) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch); + auto gemm_tile_windows = + MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + /** + * @brief Runs a single GEMM problem cooperatively by the whole workgroup. + * + * @note RunGemm2LDS runs with two shared memory buffers using a ping-pong buffering mechanism. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param ds_ptr input Ds pointers - additional fused operands consumed by the epilogue + * @param c_ptr output C pointer + * @param smem_ptr_0 The starting pointer of the 1st shared memory block. + * @param smem_ptr_1 The starting pointer of the 2nd shared memory block. + * @param kargs Grouped Convolution Backward Weight kernel arguments + * @param num_loop The number of GEMM K-loop iterations for this split-K slice. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * @param block_idx_k The GEMM's K dimension tile index processed by this workgroup (split-K). + * + */ + CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t num_loop, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t block_idx_k) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch); + auto gemm_tile_windows = + MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); + + // Run GEMM cooperatively by whole workgroup. 
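+ // Double-buffered variant: the gemm pipeline alternates between smem_ptr_0 and smem_ptr_1, + // so loading the next K tile into LDS can overlap with computation on the current one. 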
+ const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const + { + const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = + TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t num_loop = __builtin_amdgcn_readfirstlane( + ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock)); + const index_t i_k = + __builtin_amdgcn_readfirstlane(blockIdZ * num_loop * TilePartitioner::KPerBlock); + + const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y); + const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); + + // GEMM mapping for conv_bwd_weight: Weight (C) = Out (A) * In (B) + const OutDataType* a_ptr = static_cast(kargs.out_ptr) + group_offset_a; + const InDataType* b_ptr = static_cast(kargs.in_ptr) + group_offset_b; + WeiDataType* c_ptr = static_cast(kargs.wei_ptr) + group_offset_c; + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(a_ptr, + b_ptr, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + num_loop, + i_m, + i_n, + i_k); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm( + a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k); + } + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 196c468c07..8cd1710043 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -34,16 +34,16 @@ struct GroupedConvFwdKernelArgs std::is_same_v && std::is_same_v, bool>::type = false> - CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0])}; + static_cast(args.N_), + static_cast(args.C_), + 
static_cast(args.input_spatial_lengths_[0])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -56,9 +56,10 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0]; - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0]; + GemmM = args.N_ * args.output_spatial_lengths_[0]; + GemmN = args.K_; + GemmK = args.C_ * args.filter_spatial_lengths_[0]; + GemmBatch = args.G_; in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; @@ -103,18 +104,18 @@ struct GroupedConvFwdKernelArgs std::is_same_v && std::is_same_v, bool>::type = false> - CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -122,19 +123,20 @@ struct GroupedConvFwdKernelArgs static_cast(args.output_spatial_lengths_[1])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1])}; + static_cast(args.conv_filter_strides_[1])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1])}; + static_cast(args.input_left_pads_[1])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1])}; + static_cast(args.input_right_pads_[1])}; k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; + GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + GemmN = args.K_; + GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; + GemmBatch = args.G_; in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; @@ -179,20 +181,20 @@ struct GroupedConvFwdKernelArgs std::is_same_v && std::is_same_v, bool>::type = false> - CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1]), - static_cast(args.input_spatial_lengths_[2])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + 
static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1]), - static_cast(args.filter_spatial_lengths_[2])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -201,17 +203,17 @@ struct GroupedConvFwdKernelArgs static_cast(args.output_spatial_lengths_[2])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1]), - static_cast(args.conv_filter_strides_[2])}; + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1]), static_cast(args.conv_filter_dilations_[2])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1]), - static_cast(args.input_left_pads_[2])}; + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1]), - static_cast(args.input_right_pads_[2])}; + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; k_batch = args.k_batch; @@ -220,6 +222,7 @@ struct GroupedConvFwdKernelArgs GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] * args.filter_spatial_lengths_[2]; + GemmBatch = args.G_; in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; @@ -256,15 +259,15 @@ struct GroupedConvFwdKernelArgs group_stride_c = args.K_; } - using AGridDescMK = remove_cvref_t())>; - using BGridDescNK = remove_cvref_t())>; - using CGridDescMN = remove_cvref_t())>; + using AGridDescMK = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeADescriptor_M_K())>; + using BGridDescNK = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeBDescriptor_N_K())>; + using CGridDescMN = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeCDescriptor_M_N())>; static constexpr index_t NonSpatialDims = 3; array in_g_n_c_wis_lengths; @@ -280,6 +283,7 @@ struct GroupedConvFwdKernelArgs index_t GemmM; index_t GemmN; index_t GemmK; + index_t GemmBatch; const void* in_ptr; const void* wei_ptr; @@ -354,8 +358,7 @@ struct GroupedConvolutionForwardKernel using OutLayout = remove_cvref_t; using DsLayout = remove_cvref_t; - using GemmDsLayout = remove_cvref_t; - + using GemmDsLayout = remove_cvref_t; static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; @@ -389,20 +392,16 @@ struct GroupedConvolutionForwardKernel // clang-format on } - CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args) + CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs) { - const index_t GemmM = args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), - args.output_spatial_lengths_.end(), - 1, - std::multiplies()); - const index_t GemmN = args.K_; - return dim3(TilePartitioner::GridSize(GemmM, GemmN), args.G_, args.k_batch); + return dim3( + 
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized - MakeKernelArgs(const GroupedConvHostArgs& hostArgs) + MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs) { return GroupedConvFwdKernelArgsSpecialized(hostArgs); } @@ -750,7 +749,7 @@ struct GroupedConvolutionForwardKernel auto& c_block_window = gemm_tile_windows.at(I3); EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0, smem_ptr_1); + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 4b7cb3c895..b173ab25a1 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -14,14 +14,15 @@ namespace ck_tile { /// This structure is passed to Grouped Convolution Kernels when creating the kernel /// arguments object. It contains all necessary information required to /// build a proper kernel argument and launch the kernel on GPU. +template struct GroupedConvHostArgs : public conv::ConvParam { CK_TILE_HOST GroupedConvHostArgs() = delete; CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param, - const void* in_ptr_, - const void* wei_ptr_, + InPtr in_ptr_, + WeiPtr wei_ptr_, const std::vector ds_ptr_, - void* out_ptr_, + OutPtr out_ptr_, index_t k_batch_) : conv::ConvParam(conv_param), in_ptr(in_ptr_), @@ -32,13 +33,16 @@ struct GroupedConvHostArgs : public conv::ConvParam { } - const void* in_ptr; - const void* wei_ptr; + InPtr in_ptr; + WeiPtr wei_ptr; const std::vector ds_ptr; - void* out_ptr; + OutPtr out_ptr; index_t k_batch; }; +using GroupedConvFwdHostArgs = GroupedConvHostArgs; +using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs; + template ; + true, + true, + ck_tile::tensor_layout::gemm::RowMajor, + ck_tile::tensor_layout::gemm::ColumnMajor, + ck_tile::tensor_layout::gemm::RowMajor>; static constexpr index_t NumDTensor = DsLayout::size(); using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout()); }; diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp new file mode 100644 index 0000000000..b2b7918810 --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp @@ -0,0 +1,659 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
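+// A sketch of the implicit-GEMM view built by this transform, per group (names follow the members +// declared below; the exact descriptors are produced by MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N): +//   A[GemmM, GemmK] : Out viewed as (K, N * Do * Ho * Wo) +//   B[GemmN, GemmK] : In  viewed as (Z * Y * X * C, N * Do * Ho * Wo) via pad/embed/merge transforms +//   C[GemmM, GemmN] : Wei viewed as (K, Z * Y * X * C) +// so the weight gradient is accumulated as C[m, n] = sum_k A[m, k] * B[n, k], with +// GemmK = N * output spatial size being the dimension split by split-K. 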
+ +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" + +namespace ck_tile { + +template +struct TransformConvBwdWeightToGemm +{ + private: + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; + static constexpr auto I4 = number<4>{}; + static constexpr auto I5 = number<5>{}; +#if 0 // TODO: Enable these functionalities + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const IndexType N = a_g_n_c_wis_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Split Convolution's N dimension into N workgroups. However + // this still might not result in sufficiently small tensor, + // but at least later on we could divide the image as well. + return 1; + } + } + else + { + // Split N is not needed. 
+ return N; + } + } +#endif + + public: + CK_TILE_HOST constexpr TransformConvBwdWeightToGemm() {} + + template + CK_TILE_HOST TransformConvBwdWeightToGemm( + const TransformConvBwdWeightToGemmBase& transform_conv_fwd_to_gemm_base) + : G_{static_cast(transform_conv_fwd_to_gemm_base.G_)}, + N_{static_cast(transform_conv_fwd_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_fwd_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_fwd_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_fwd_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_fwd_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_fwd_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_fwd_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_fwd_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_fwd_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_fwd_to_gemm_base.X_)}, + K_{static_cast(transform_conv_fwd_to_gemm_base.K_)}, + C_{static_cast(transform_conv_fwd_to_gemm_base.C_)}, + ConvStrideD_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadH_)}, + InRightPadW_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadW_)}, + ZYX_{static_cast(transform_conv_fwd_to_gemm_base.ZYX_)} + { + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{I1}, + Hi_{I1}, + Wi_{a_g_n_c_wis_lengths[I3]}, + Do_{I1}, + Ho_{I1}, + Wo_{c_g_n_k_wos_lengths[I3]}, + Z_{I1}, + Y_{I1}, + X_{b_g_k_c_xs_lengths[I3]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{I1}, + ConvStrideW_{conv_filter_strides[I0]}, + ConvDilationD_{I1}, + ConvDilationH_{I1}, + ConvDilationW_{conv_filter_dilations[I0]}, + InLeftPadD_{I0}, + InLeftPadH_{I0}, + InLeftPadW_{input_left_pads[I0]}, + InRightPadD_{I0}, + InRightPadH_{I0}, + InRightPadW_{input_right_pads[I0]}, + ZYX_{X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + 
const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{I1}, + Hi_{a_g_n_c_wis_lengths[I3]}, + Wi_{a_g_n_c_wis_lengths[I4]}, + Do_{I1}, + Ho_{c_g_n_k_wos_lengths[I3]}, + Wo_{c_g_n_k_wos_lengths[I4]}, + Z_{I1}, + Y_{b_g_k_c_xs_lengths[I3]}, + X_{b_g_k_c_xs_lengths[I4]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{conv_filter_strides[I0]}, + ConvStrideW_{conv_filter_strides[I1]}, + ConvDilationD_{I1}, + ConvDilationH_{conv_filter_dilations[I0]}, + ConvDilationW_{conv_filter_dilations[I1]}, + InLeftPadD_{I0}, + InLeftPadH_{input_left_pads[I0]}, + InLeftPadW_{input_left_pads[I1]}, + InRightPadD_{I0}, + InRightPadH_{input_right_pads[I0]}, + InRightPadW_{input_right_pads[I1]}, + ZYX_{Y_ * X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{a_g_n_c_wis_lengths[I3]}, + Hi_{a_g_n_c_wis_lengths[I4]}, + Wi_{a_g_n_c_wis_lengths[I5]}, + Do_{c_g_n_k_wos_lengths[I3]}, + Ho_{c_g_n_k_wos_lengths[I4]}, + Wo_{c_g_n_k_wos_lengths[I5]}, + Z_{b_g_k_c_xs_lengths[I3]}, + Y_{b_g_k_c_xs_lengths[I4]}, + X_{b_g_k_c_xs_lengths[I5]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{conv_filter_strides[I0]}, + ConvStrideH_{conv_filter_strides[I1]}, + ConvStrideW_{conv_filter_strides[I2]}, + ConvDilationD_{conv_filter_dilations[I0]}, + ConvDilationH_{conv_filter_dilations[I1]}, + ConvDilationW_{conv_filter_dilations[I2]}, + InLeftPadD_{input_left_pads[I0]}, + InLeftPadH_{input_left_pads[I1]}, + InLeftPadW_{input_left_pads[I2]}, + InRightPadD_{input_right_pads[I0]}, + InRightPadH_{input_right_pads[I1]}, + InRightPadW_{input_right_pads[I2]}, + ZYX_{Z_ * Y_ * X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + +#if 0 // TODO: Enable these functionalities + __host__ bool AreDescriptorsSmallerThan2GB() const + { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_; + + bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= 
TwoGB; + bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB; + + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate the effective filter size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow the split only if the whole left padding falls in the left tensor and the right + // padding in the right tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); + const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) + { + // Apply new sizes + // Split the output in half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + // Calculate offsets + a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + c_right_offset = (Do_ / 2) * DoStride_; + } + else if(is_possible_to_split_h) + { + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; + + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + c_right_offset = (Ho_ / 2) * HoStride_; + } + else if(is_possible_to_split_w) + { + 
conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; + + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + + conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); + + a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + c_right_offset = (Wo_ / 2) * WoStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck_tile::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } +#endif + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NWGK + const index_t NDoHoWoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_), + make_tuple(KStride, NDoHoWoStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + // NWGC + const index_t NStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), + make_tuple(NStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKXC + const index_t KStride = X_ * C_; + constexpr auto CXStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NHWGK + const index_t NDoHoWoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_), + make_tuple(KStride, NDoHoWoStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + // NHWGC + const index_t NStride = Hi_ * Wi_ * G_ * C_; + const index_t HiStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStride, HiStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKYXC + const index_t KStride = Y_ * X_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_), + make_tuple(KStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NDHWGK + const index_t NDoHoWoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_), + make_tuple(KStride, NDoHoWoStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + const index_t NStride = Di_ * Hi_ * Wi_ * G_ * C_; + const index_t DiStride = Hi_ * Wi_ * G_ * C_; + const index_t 
HiStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKZYXC + const index_t KStride = Z_ * Y_ * X_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_), + make_tuple(KStride, CStride)); + } + + // TODO: implement ck_tile::tensor_layout::convolution that describes packed/strided dimensions as + // properties + + template ::type = false> + CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const + { + const auto out_grid_desc = make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // B: input tensor comes in K_N + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + const auto in_gemmn_gemmktotal_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X_, C_)), + make_merge_transform(make_tuple(N_, Wo_))), + make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } + + template ::type = false> + CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const + { + const auto out_grid_desc = make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // B: input tensor comes in K_N + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto in_gemmn_gemmktotal_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)), + 
make_merge_transform(make_tuple(N_, Ho_, Wo_))), + make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } + + template ::type = false> + CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const + { + const auto out_grid_desc = make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // B: input tensor comes in K_N + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)), + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))), + make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } + + IndexType G_, N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType ZYX_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index c93329bfbe..434be9f84a 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -380,6 +380,6 @@ struct BlockReduce2D // deduction guide template -CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D; +CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D; } // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 6a1f926a9a..62c9944bd2 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -272,4 +272,137 @@ struct BlockReduce2dCrossWarpSync } }; +template +struct BlockReduce2dTreeCrossWarpSync +{ + using Problem = remove_cvref_t; + using BlockShape = typename Problem::BlockShape; + + template + CK_TILE_DEVICE static 
constexpr index_t GetReduceWarps() + { + constexpr index_t num_reduce_warps = [&]() { + using Dstr = typename YDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_warp = 0; + + index_t len_ = 1; + static_for<0, NDimR, 1>{}([&](auto idim_r) { + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + len_ *= r_length; + } + }); + return len_; + }(); + return num_reduce_warps; + } + + // returned size is in bytes + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + using DataType = typename YDistributedTensor_::DataType; + constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); + + // we need to store all data from every wave into smem + // e.g. 2x2 reduce along N + // -------------> reduce N + // | w0 | w1 | ___> | w01 | + // | w2 | w3 | | w23 | + // + // -> store data from every wave into LDS + // + // + // -------------> reduce N + // | w0 | w1 | w2 | w3 | -----> | w0123 | + // + // -> also store data from every wave into LDS + constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + return num_warps * thread_buf_size * sizeof(DataType); + } + + template + CK_TILE_DEVICE void + operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func) + { + using Dstr = typename YDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + using DataType = typename YDistributedTensor_::DataType; + + constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_lane = NDimP - 1; + constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); + + DataType* smem_ptr = reinterpret_cast(smem); + const index_t lane_id = get_lane_id(); + const index_t warp_id = get_warp_id(); + + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); + constexpr index_t num_reduce_warps = GetReduceWarps(); + + if constexpr(num_reduce_warps == 1) + return; + + // Each warp's lane 0 writes its partial results to shared memory + const index_t smem_offset = warp_id; + if(lane_id == 0) + { + static_for<0, thread_buf_size, 1>{}([&](auto i) { + // Store the i-th element of this warp's thread_buffer into SMEM + smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i]; + }); + } + block_sync_lds(); + + // We let each warp hold a duplicate of the per-warp partials and perform the reduction. + static_for<0, thread_buf_size, 1>{}([&](auto i) { + DataType v = 0; + if(lane_id < num_reduce_warps) + { + v = smem_ptr[lane_id + i * num_warps]; + } + + // cross-lane reduce for replication + // only reduce on the R dimension corresponding to the lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! 
only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // reduction sweep forward + static_for<0, nstage, 1>{}([&](auto istage) { + // pull data from remote lane + const auto o = + __shfl_xor(v, number{}.value); + + // reduce + v = reduce_func(v, o); + }); + } + }); + + y_tensor.get_thread_buffer()(i) = v; + }); + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 3eec2a1ab6..610541b2e4 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index f0251177d4..6cb81b8856 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -58,13 +58,14 @@ struct Rmsnorm2dFwd static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant; - static constexpr index_t Block_M = Problem::BlockShape::Block_M; - static constexpr index_t Block_N = Problem::BlockShape::Block_N; - static constexpr bool kPadM = false; // always no need to pad along M - static constexpr bool kPadN = Problem::Traits::kPadN; - static constexpr bool kTwoPass = Problem::Traits::kTwoPass; - static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; - static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; + static constexpr index_t Block_M = Problem::BlockShape::Block_M; + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + static constexpr bool kPadM = false; // always no need to pad along M + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool kTwoPass = Problem::Traits::kTwoPass; + static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; + static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; + static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm; static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; @@ -150,6 +151,8 @@ struct Rmsnorm2dFwd if (kPadN) n += "_pn"; if (kSaveInvRms) n += "_rms"; if (kTwoPass) n += "_2p"; + if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL) n += "_nsm"; + else if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE) n += "_t5ml"; return n; }(); auto prec_str = [&] () { diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp index 356a2e12ca..df689c6b46 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp @@ -69,6 +69,15 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy return BlockReduce2dCrossWarpSync{}; } + 
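+ // The tree variant returned below stages each warp's partial tile in LDS once, then + // combines the per-warp partials with a log2-depth __shfl_xor butterfly over the + // power-of-2 reduction length (see BlockReduce2dTreeCrossWarpSync above). 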
template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dTreeCrossWarpSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dTreeCrossWarpSync{}; + } + template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp new file mode 100644 index 0000000000..810c3c5243 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +/** + * @brief This T5Pass implements the RMSNorm2d forward pipeline as a variant + * based on Rmsnorm2dFwdPipelineOnePass and Rmsnorm2dFwdPipelineTwoPass using a T5 model-like + * method. + * + * The T5 model, developed by Google, is a transformer-based architecture designed to perform + * a variety of NLP tasks. The T5-like approach employed here is characterized by how RMS + * normalization is handled, particularly where intermediate values are cast to BF16. This aims to + * achieve a similar value distribution to that produced by the VLLM hip implementation, thereby + * enhancing model accuracy. + * + * Note: While this implementation improves precision and can reduce discrepancies with VLLM, it is + * not guaranteed to eliminate all differences or ensure uniform outcomes across every use case. + * + * This implementation is a variant based on the original one-pass and two-pass approaches, + * allowing for both fused and non-fused add operations. 
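+ * + * As a conceptual sketch (illustrative only, for a bf16 output type): the T5-like path computes + *   y = bf16( bf16(x * inv_rms) * gamma ) + * i.e. the normalized value is rounded to bf16 before the gamma multiply, whereas the regular + * one-pass/two-pass pipelines keep the intermediate in the compute type and convert once at the + * end. 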
+ */ + +template +struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using InvRmsDataType = ck_tile::remove_cvref_t; + + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; + static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; + static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr_op"; // block per row + else + return "wpr_op"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const XResidualWindow& x_residual_window_, + const GammaWindow& gamma_window_, + YWindow& y_window_, + const YResidualWindow& y_residual_window_, + InvRmsWindow& inv_rms_window, + const SmoothScaleWindow& sm_scale_window_, + YScaleWindow& y_scale_window_, + UnquantYWindow& unquant_y_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem, + Epilogue) const + { + const auto x_window = + make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); + const auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + const auto x_residual_window = make_tile_window( + x_residual_window_, Policy::template MakeXBlockTileDistribution()); + auto y_residual_window = make_tile_window( + y_residual_window_, Policy::template MakeXBlockTileDistribution()); + + auto reduce_square_sum_func = ReduceOp::SquareAdd{}; + auto reduce_sum_func = ReduceOp::Add{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_tree_cross_warp_sync = + Policy::template GetBlockReduce2dTreeCrossWarpSync(); + + auto x = load_tile(x_window); + auto x_resi = load_tile(x_residual_window); + + // load gamma (TODO: support no gamma?) 
+ const auto gamma = load_tile(gamma_window); + + auto acc = cast_tile(x); + + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD || + kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + [[maybe_unused]] auto pre_out = + make_static_distributed_tensor(x.get_tile_distribution()); + + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + acc(idx) = type_convert(x_resi(idx)) + acc(idx); + + // To make norm input align with residual output + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + if constexpr(std::is_same_v) + { + pre_out(idx) = float_to_bf16(acc(idx)); + } + else + { + pre_out(idx) = type_convert(acc(idx)); + } + acc(idx) = type_convert(pre_out(idx)); + } + }); + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + store_tile(y_residual_window, pre_out); + } + } + + // compute mean square each-thread->cross-lane->cross-warp + auto square_sum = block_reduce2d.template MakeYBlockTile(); + set_tile(square_sum, 0); + if constexpr(Problem::BlockShape::Vector_N % 2 == 0) + { + sweep_tile( + acc, + [&](auto idx_0, auto idx_1) { + square_sum(idx_0) += acc[idx_0] * acc[idx_0] + acc[idx_1] * acc[idx_1]; + }, + sequence<1, 2>{}); + } + else + { + square_sum = block_reduce2d(acc, + reduce_square_sum_func.GetIdentityValue(), + reduce_square_sum_func); + } + block_reduce2d_sync(square_sum, reduce_sum_func); + block_reduce2d_tree_cross_warp_sync(square_sum, smem, reduce_sum_func); + + // compute inv-rms + auto inv_rms = tile_elementwise_in( + [&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum); + + if constexpr(kSaveInvRms) + store_tile(inv_rms_window, cast_tile(inv_rms)); + + // rmsnorm computation + auto rmsn = make_static_distributed_tensor(x.get_tile_distribution()); + sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + if constexpr(std::is_same_v) + { + const auto tmp0 = + float_to_bf16(acc[idx] * inv_rms_[i_idx]); + const auto tmp1 = float_to_bf16( + type_convert(tmp0) * gamma_); + const auto rmsn_ = type_convert(tmp1); + rmsn(idx) = rmsn_; + } + else + { + const auto tmp = type_convert(acc[idx] * inv_rms_[i_idx]); + const auto rmsn_ = type_convert(tmp) * gamma_; + rmsn(idx) = rmsn_; + } + }); + + if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) + { + if constexpr(kSaveUnquant) + { + Epilogue{}( + unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem); + } + else + { + Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem); + } + } + else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) + { + if constexpr(kSaveUnquant) + { + Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem); + } + else + { + Epilogue{}(y_window_, y_scale_window_, rmsn, smem); + } + } + else + { + Epilogue{}(y_window_, rmsn); + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp index 58159142d0..c77d61872e 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -117,10 +117,7 @@ struct Rmsnorm2dFwdPipelineOnePass // compute inv-rms auto inv_rms = tile_elementwise_in( - [&](const auto& v_) { - return type_convert(1.0f) / 
(sqrt(v_ / row_size + epsilon)); - }, - square_sum); + [&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum); if constexpr(kSaveInvRms) store_tile(inv_rms_window, cast_tile(inv_rms)); diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp index 152da60c01..b91f17ffdd 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp @@ -37,20 +37,37 @@ template<> struct Rmsnorm2dFusedQuantEnumName struct Rmsnorm2dFusedQuantEnumName { static constexpr const char * name = "smdqt"; }; // clang-format on +enum class Rmsnorm2dSensitiveEnum +{ + NO_SPECIFIC_MODEL = 0, + // T5-like model for RMSNorm. The T5 model, developed by Google, is a transformer-based + // architecture designed for a variety of NLP tasks. This option mimics T5's approach to + // RMSNorm, aiming to ensure similar value distributions and enhance accuracy. + T5_MODEL_LIKE = 1, +}; + +// clang-format off +template struct Rmsnorm2dSensitiveEnumName; +template<> struct Rmsnorm2dSensitiveEnumName { static constexpr const char * name = "nsm"; }; +template<> struct Rmsnorm2dSensitiveEnumName { static constexpr const char * name = "t5ml"; }; +// clang-format on + template + Rmsnorm2dFusedQuantEnum kFusedQuant_, + Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm_> struct Rmsnorm2dFwdTraits { - static constexpr bool kPadN = kPadN_; - static constexpr bool kSaveInvRms = kSaveInvRms_; - static constexpr bool kSaveUnquant = kSaveUnquant_; - static constexpr bool kTwoPass = kTwoPass_; - static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_; - static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kSaveUnquant = kSaveUnquant_; + static constexpr bool kTwoPass = kTwoPass_; + static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_; + static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; + static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_; }; } // namespace ck_tile diff --git a/include/ck_tile/ref/naive_attention.hpp b/include/ck_tile/ref/naive_attention.hpp index 98ceab6992..172fcee2e3 100644 --- a/include/ck_tile/ref/naive_attention.hpp +++ b/include/ck_tile/ref/naive_attention.hpp @@ -695,18 +695,18 @@ struct naive_attention_fwd_kernel static_cast(variation_), \ static_cast(quant_algo_)>; \ using k_ = naive_attention_fwd_kernel; \ + k_type_, \ + v_type_, \ + o_type_, \ + acc_type_, \ + kvscale_type_, \ + q_layout_, \ + k_layout_, \ + v_layout_, \ + o_layout_, \ + k_scale_layout_, \ + v_scale_layout_, \ + ktraits_>; \ dim3 grids = k_::get_grid_size(a); \ r = ck_tile::launch_kernel(s, \ ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a)); \ diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index 9f2ef3389f..6f5a425207 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -1,14 +1,8 @@ -from datetime import datetime -import pathlib -from pathlib import Path -import subprocess -import os -import copy +from datetime import datetime import pathlib from pathlib import Path import subprocess import os + import copy -NS = 'ck_tile' -OPS = 'ops' -REF = 'ref' -OPS_COMMON = 'common' # common header will be duplicated into ops/* other module + NS = 'ck_tile' OPS = 'ops' REF = 'ref' OPS_COMMON = + 'common' #common header 
will be duplicated into ops/* other module HEADER_COMMON = f"""// SPDX-License-Identifier: MIT // Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n @@ -82,7 +76,7 @@ submodule = submodule_t() # formatting for x in all_files: subprocess.Popen(f'dos2unix {str(x)}', shell=True) - cmd = f'clang-format-12 -style=file -i {str(x)}' + cmd = f'clang-format-18 -style=file -i {str(x)}' #for xp in x.parents: #print(get_file_base(x)) subprocess.Popen(cmd, shell=True) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index 120bf7484a..59dfd76ede 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -116,7 +116,7 @@ struct ReferenceMoeGemm : public device::BaseOperator #if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); #else - v_a = i4 - 8; + v_a = i4 - 8; #endif } else diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp index eedd687bde..9f04cf3e3d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp @@ -110,7 +110,7 @@ struct ReferenceMoeGemm1BlockScale : public device::BaseOperator #if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); #else - v_a = i4 - 8; + v_a = i4 - 8; #endif } else diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp index 2c2cac77e3..28274a5154 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -25,17 +25,17 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - naive_gemm_kernel(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - CDataType* __restrict__ p_c_grid, - index_t m, - index_t n, - index_t k, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op) + naive_gemm_kernel(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + index_t m, + index_t n, + index_t k, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) { using RowMajor = ck::tensor_layout::gemm::RowMajor; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp index 681f466677..2f0c6113de 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp @@ -23,8 +23,9 @@ template using S = ck::Sequence; template -using 
device_column_to_image_bf16_instances = std::tuple< - // clang-format off +using device_column_to_image_bf16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -39,12 +40,13 @@ using device_column_to_image_bf16_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_column_to_image_f16_instances = std::tuple< - // clang-format off +using device_column_to_image_f16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -59,12 +61,13 @@ using device_column_to_image_f16_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_column_to_image_f32_instances = std::tuple< - // clang-format off +using device_column_to_image_f32_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -76,12 +79,13 @@ using device_column_to_image_f32_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4> - // clang-format on - >; + // clang-format on + >; template -using device_column_to_image_i8_instances = std::tuple< - // clang-format off +using device_column_to_image_i8_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -97,8 +101,8 @@ using device_column_to_image_i8_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 8>, DeviceColumnToImageImpl, 16> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp index 74a2155a04..2d2798b667 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp @@ -23,8 +23,9 @@ template using S = ck::Sequence; template -using device_image_to_column_bf16_instances = std::tuple< - // clang-format off +using device_image_to_column_bf16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -38,12 +39,13 @@ using device_image_to_column_bf16_instances = 
std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_image_to_column_f16_instances = std::tuple< - // clang-format off +using device_image_to_column_f16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -58,12 +60,13 @@ using device_image_to_column_f16_instances = std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_image_to_column_f32_instances = std::tuple< - // clang-format off +using device_image_to_column_f32_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -75,12 +78,13 @@ using device_image_to_column_f32_instances = std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4> - // clang-format on - >; + // clang-format on + >; template -using device_image_to_column_i8_instances = std::tuple< - // clang-format off +using device_image_to_column_i8_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -96,8 +100,8 @@ using device_image_to_column_i8_instances = std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 8>, DeviceImageToColumnImpl, 16> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp index 93eed31bc5..6543e3df23 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
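The next hunk registers the new WMMA b-scale instances alongside the existing XDL ones, with each declaration and registration guarded by its backend macro, so one factory header serves both architectures and still builds when either backend is disabled. A minimal sketch of the guard pattern (Op, add_wmma_instances, and add_xdl_instances are hypothetical stand-ins, not CK API):

    #include <memory>
    #include <vector>

    struct Op { virtual ~Op() = default; }; // stand-in for a CK device operation

    // hypothetical per-backend registration hooks
    void add_wmma_instances(std::vector<std::unique_ptr<Op>>& ops);
    void add_xdl_instances(std::vector<std::unique_ptr<Op>>& ops);

    std::vector<std::unique_ptr<Op>> get_instances()
    {
        std::vector<std::unique_ptr<Op>> ops;
    #ifdef CK_USE_WMMA
        add_wmma_instances(ops); // present only when WMMA (gfx11/gfx12) targets are built
    #endif
    #ifdef CK_USE_XDL
        add_xdl_instances(ops); // present only when XDL (gfx9) targets are built
    #endif
        return ops;
    }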
#pragma once @@ -17,6 +17,22 @@ namespace tensor_operation { namespace device { namespace instance { #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_WMMA +void add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances); +#endif +#ifdef CK_USE_XDL void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( std::vector>>& instances); #endif +#endif template && is_same_v && is_same_v) { +#ifdef CK_USE_WMMA + add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); +#endif +#ifdef CK_USE_XDL add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); +#endif } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index 5c0d7283f2..11a8ff8e91 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -112,6 +112,36 @@ using device_grouped_conv_bwd_data_xdl_f16_instances = // clang-format on >; +template +using device_grouped_conv_bwd_data_xdl_f16_nchw_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 
1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 8, 1, 32>, 2> + // clang-format on + >; // bf16_bf16_f32_bf16 template -using device_gemm_xdl_universal_km_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_km_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -56,8 +57,8 @@ using device_gemm_xdl_universal_km_kn_mn_comp_instances = std::tuple< DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_dl_f32_instances = std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_dl_f32_instances = + std::tuple< + // clang-format off //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_dl_f16_instances = std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_dl_f16_instances = + std::tuple< + // clang-format off //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_dl_bf16_instances = std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_dl_bf16_instances = + std::tuple< + // clang-format off //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 
2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp index 40c4d558b8..47cb9a88a4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp @@ -37,9 +37,8 @@ template -using device_grouped_conv_bwd_weight_wmma_f16_instances = - std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_wmma_f16_instances = std::tuple< + // clang-format off //#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| @@ -71,17 +70,16 @@ using device_grouped_conv_bwd_weight_wmma_f16_instances = DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_wmma_i8_instances = - std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_wmma_i8_instances = std::tuple< + // clang-format off //#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| 
CBlockTransferClusterLengths| CBlockTransfer| //#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| @@ -110,8 +108,8 @@ using device_grouped_conv_bwd_weight_wmma_i8_instances = DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index d1466206f0..90e8dc0221 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -3,6 +3,7 @@ function(add_instance_library INSTANCE_NAME) set(result 1) if(DEFINED DTYPES) foreach(source IN LISTS ARGN) + get_filename_component(source_name ${source} NAME) set(test 0) foreach(type IN LISTS DTYPES) if(type MATCHES "fp16") @@ -19,13 +20,13 @@ function(add_instance_library INSTANCE_NAME) set(type1 "_i8") endif() #make an exception for reduction kernels - if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance" OR ${source} MATCHES "device_image_to_column") + if("${source_name}" MATCHES "${type}" OR "${source_name}" MATCHES "${type1}" OR "${source_name}" MATCHES "device_reduce_instance" OR ${source_name} MATCHES "device_image_to_column") #if filename matches any selected type, exit type loop and do no exclude the file from the list set(test 0) break() - elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND - NOT(source MATCHES type OR source MATCHES type1)) + elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR + source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND + NOT (source_name MATCHES type OR source_name MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) endif() @@ -39,66 +40,52 @@ function(add_instance_library INSTANCE_NAME) set(INST_TARGETS 
${SUPPORTED_GPU_TARGETS}) - # Do not build DPP instances if DPP_KERNELS macro is not set foreach(source IN LISTS ARGN) - if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp") + get_filename_component(source_name ${source} NAME) + + # Do not build DPP instances if DPP_KERNELS macro is not set + if(NOT DEFINED DPP_KERNELS AND source_name MATCHES "_dpp") message(DEBUG "removing dpp instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - # Do not build DL instances if DL_KERNELS macro is not set - foreach(source IN LISTS ARGN) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + # Do not build DL instances if DL_KERNELS macro is not set + if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl") message(DEBUG "removing dl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - # Do not build XDL instances if gfx9 targets are not on the target list - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + # Do not build XDL instances if gfx9 targets are not on the target list + if(NOT INST_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") message(DEBUG "removing xdl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - # Do not build MX instances if gfx950 targets are not on the target list - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + # Do not build MX instances if gfx950 targets are not on the target list + if(NOT INST_TARGETS MATCHES "gfx950" AND source_name MATCHES "_mx") message(DEBUG "removing MX instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - # Do not build WMMA instances if gfx11 targets are not on the target list - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + # Do not build WMMA instances if gfx11 targets are not on the target list + if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") message(DEBUG "removing wmma instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - # Do not build mha instances if gfx94 or gfx90a targets are not on the target list - foreach(source IN LISTS ARGN) - if((NOT BUILD_MHA_LIB OR (NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND NOT INST_TARGETS MATCHES "gfx95")) AND source MATCHES "mha") - message(DEBUG "removing mha instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() - endforeach() - # Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 - if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply" AND source MATCHES "_f8_") + # Do not build mha instances if gfx94 or gfx90a targets are not on the target list + if((NOT BUILD_MHA_LIB OR (NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND NOT INST_TARGETS MATCHES "gfx95")) AND source_name MATCHES "mha") + message(DEBUG "removing mha instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + # Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 + if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_") 
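+ # (source_name is the bare filename; matching it instead of the full path keeps directory components such as "fp16" from tripping these filters)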
message(DEBUG "removing gemm_multiply_multiply_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_") + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "_f8_") message(DEBUG "removing gemm_universal_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - endif() - # Do not build WMMA gemm_universal_f8 for any targets except gfx12+ - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "gemm_wmma_universal" AND source MATCHES "_f8_") + endif() + # Do not build WMMA gemm_universal_f8 for any targets except gfx12+ + if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_") message(DEBUG "removing gemm_universal_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -109,41 +96,43 @@ function(add_instance_library INSTANCE_NAME) if(ARGN) set(INST_OBJ) foreach(source IN LISTS ARGN) + get_filename_component(source_name ${source} NAME) + set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) - if(source MATCHES "_xdl") + if(source_name MATCHES "_xdl") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(source MATCHES "_wmma") + elseif(source_name MATCHES "_wmma") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) - elseif(source MATCHES "mha") + elseif(source_name MATCHES "mha") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() - if(source MATCHES "_mx") + if(source_name MATCHES "_mx") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() #only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) - if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8") + if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() - if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8") + if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() else() - if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8") + if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS 
gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() - if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8") + if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() endif() - if(source MATCHES "gemm_wmma_universal" AND source MATCHES "f8") + if(source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "f8") list(FILTER INST_TARGETS INCLUDE REGEX "gfx12") endif() set(offload_targets) foreach(target IN LISTS INST_TARGETS) - string(APPEND offload_targets "--offload-arch=${target} ") + string(APPEND offload_targets "--offload-arch=${target} ") endforeach() set_source_files_properties(${source} PROPERTIES COMPILE_FLAGS ${offload_targets}) list(APPEND INST_OBJ ${source}) @@ -165,7 +154,7 @@ function(add_instance_library INSTANCE_NAME) list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) target_compile_options(device_mha_instance PRIVATE ${FMHA_COMPILE_OPTIONS}) endif() - + target_compile_features(${INSTANCE_NAME} PUBLIC) # flags to compress the library diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp index 659d6a99a9..34b580cf75 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp @@ -31,9 +31,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| 
_MBlock_MPerBlock| _NPerBlock| | | @@ -47,8 +46,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, 
PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -52,8 +51,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -55,8 +54,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | 
PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -47,8 +46,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, 
F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -52,8 +51,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -55,8 +54,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( std::vector; // FIXME: retire dedicated 2D version -using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = - std::tuple< - // clang-format off +using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off //#####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //#####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //#####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| @@ -122,8 +121,8 @@ using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instan DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void 
add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( std::vector -using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances = - std::tuple< - // clang-format off +using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances = std::tuple< + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| @@ -51,8 +50,8 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances = DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, 
PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> - // clang-format on - >; + // clang-format on + >; // double rate mfma instances on gfx950 -using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances_2x = - std::tuple< - // clang-format off +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances_2x = std::tuple< + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 256, 64, 64, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp index 5a0c52c2df..ab5f40e81d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -27,8 +26,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_instances( OwnerList& instances) diff --git 
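The tuning rows above can be sanity-checked by hand once the table headers are decoded. Taking the gfx950 double-rate i8 row above (BlockSize 256, a 128x128x128 per-block tile, 32x32 XDL, 2x2 XdlPerWave), the wave decomposition has to fill the workgroup exactly. A minimal standalone check; the column meanings come from the headers in these files, while the 64-lane wavefront is an assumption (CDNA GPUs) that the diff itself does not state:

// Sanity check for one row from the hunk above:
//   DeviceGemm_Xdl_CShuffle< ..., 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, ... >
constexpr int BlockSize   = 256;
constexpr int MPerBlock   = 128, NPerBlock   = 128;
constexpr int MPerXdl     = 32,  NPerXdl     = 32;
constexpr int MXdlPerWave = 2,   NXdlPerWave = 2;
constexpr int WaveSize    = 64; // assumption: CDNA wavefront size, not stated in this diff

constexpr int MWaves = MPerBlock / (MXdlPerWave * MPerXdl); // 128 / 64 = 2
constexpr int NWaves = NPerBlock / (NXdlPerWave * NPerXdl); // 128 / 64 = 2
static_assert(MWaves * NWaves * WaveSize == BlockSize,
              "wave tiling must exactly fill the 256-thread workgroup");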
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp index 59ffb80bd4..6f368a44d3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -20,8 +19,8 @@ using Instances = //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_opt_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp index a64424e8ac..7049732e41 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -27,8 +26,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp index a0dd60c0f5..eef7e728d2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp @@ -9,9 +9,8 @@ namespace device { namespace 
instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = - std::tuple< - // clang-format off +using Instances = std::tuple< + // clang-format off // pipeline v1, 1 wave //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | @@ -25,8 +24,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp index 122fff4960..e966b3ec49 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -27,8 +26,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void 
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp index 9f459aabfc..e090b157b3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -20,8 +19,8 @@ using Instances = //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_opt_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp index 3671bea7a3..811358a3d3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -27,8 +26,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp index 98db8bad1c..a9ee03ca49 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp @@ -9,9 +9,8 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< - // clang-format off +using Instances = std::tuple< + // clang-format off // pipeline v1, 1 wave //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | @@ -34,8 +33,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp index 532c348b7e..d4e5ab8014 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -36,8 +35,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 
1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp index b931b8fdfd..03fdf13bc4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -20,8 +19,8 @@ using Instances = //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_opt_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp index fa53a3bf0f..c3ab756f3b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -36,8 +35,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp index a590413acc..aa895fc0cd 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp @@ -28,9 +28,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = - std::tuple< - // clang-format off +using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| @@ -43,8 +42,8 @@ using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 
1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( std::vector -using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< - // clang-format off +using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = + std::tuple< + // clang-format off //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -46,8 +47,8 @@ using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = st DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> - // clang-format on - >; + // clang-format on + >; template using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< diff --git 
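Every hunk above reshapes the same pattern: a "using ... = std::tuple<...>" list of fully-specialized kernel types, consumed by the add_device_..._instances function declared right after it (the helper behind that lives in add_device_operation_instance.hpp, which the new header further below includes). As a rough standalone analogue of the mechanism, not CK's actual helper, with all names invented for the sketch:

// Minimal analogue of CK's instance registration (illustrative; BaseOp,
// DeviceGemmInstance, and add_instances are invented for this sketch).
#include <memory>
#include <tuple>
#include <vector>

struct BaseOp
{
    virtual ~BaseOp() = default;
};

// Stand-in for one fully-specialized device kernel type; in the lists above
// every template argument (layouts, tile sizes, S<...> clusters) is baked in.
template <int MPerBlock, int NPerBlock>
struct DeviceGemmInstance : BaseOp
{
};

using Instances = std::tuple<DeviceGemmInstance<128, 128>, DeviceGemmInstance<64, 32>>;

template <typename Tuple>
void add_instances(std::vector<std::unique_ptr<BaseOp>>& ops)
{
    // The tuple value carries no data; it only names the types to instantiate.
    std::apply([&](auto... tags) { (ops.push_back(std::make_unique<decltype(tags)>()), ...); },
               Tuple{});
}

int main()
{
    std::vector<std::unique_ptr<BaseOp>> ops;
    add_instances<Instances>(ops); // ops.size() == 2, one object per tuple entry
}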
a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt
index 424320fa8f..34f51f5f58 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt
@@ -1,10 +1,12 @@
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 set(GEMM_B_SCALE_INSTANCES)

 list(APPEND GEMM_B_SCALE_INSTANCES
    device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
+   device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
 )

 set_source_files_properties(device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")

-add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES})
\ No newline at end of file
+add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES})
diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp
new file mode 100644
index 0000000000..9476eb6bf0
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp
@@ -0,0 +1,72 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
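Throughout these tables, S<4, 64, 1> and friends are compile-time integer sequences; the new header below appears to alias S to CK's Sequence. For the BlockTransferThreadClusterLengths_K0_*_K1 columns, the product of the entries is the number of cooperating threads, so it must match the instance's BlockSize. A standalone re-implementation of that invariant, with values taken from rows above:

// Illustrative re-implementation; CK's real Sequence lives in its own headers.
template <int... Is>
struct Sequence
{
    static constexpr int product = (Is * ... * 1);
};

template <int... Is>
using S = Sequence<Is...>;

// BlockSize-256 rows above pair with S<4, 64, 1> thread clusters (4*64*1 == 256),
// BlockSize-64 rows with S<4, 16, 1> (4*16*1 == 64).
static_assert(S<4, 64, 1>::product == 256, "thread cluster must fill the 256-thread block");
static_assert(S<4, 16, 1>::product == 64, "thread cluster must fill the 64-thread block");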
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I4 = pk_i4_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Compute| Compute| PermuteA| PermuteB| + //################################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| TypeA| TypeB| | | + //################################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| Scheduler| Verision| | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //1 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, 
PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //2 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //3 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //4 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //5 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //7 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //8 + + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //10 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //11 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //13 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //14 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //15 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //16 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //17 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //19 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false> //20 + + // clang-format on + >; +} // namespace instance 
+} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..9c196a3c58 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp index ce5cf21a85..1f8ca4d23a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
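The new header above defines twenty-one f16 x pk_i4 b-scale WMMA instances that differ only in tile shape, scheduler, and pipeline version. For reference, the first tuple entry (//0) decodes against the header columns as follows; this restatement only re-groups arguments that already appear in the table, using the aliases declared at the top of the file and substituting GemmDefault for the GemmSpec template parameter:

    // Entry //0 of device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances,
    // annotated against the header columns of the table above.
    using Instance0 = DeviceGemm_BScale_Wmma_CShuffleV3<
        Row, Col, Row,                          // ALayout (m x k), BLayout (n x k), CLayout (m x n)
        F16, I4, F16, F16, F32, F16,            // AData, BData (packed int4), BScale, CData, AccData, CShuffle data
        PassThrough, PassThrough, PassThrough,  // A/B/C elementwise operations
        GemmDefault,                            // GEMM specialization (padding policy)
        256,                                    // BlockSize: threads per workgroup
        1, 128,                                 // Scale Block N, Scale Block K
        128, 128, 64,                           // MPerBlock, NPerBlock, KPerBlock tile
        8, 8,                                   // AK1, BK1
        16, 16,                                 // MPerWmma, NPerWmma
        4, 2,                                   // MRepeat, NRepeat
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // A block transfer: cluster, orders, vector dim, scalars/vector, AddExtraM
        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // B block transfer: cluster, orders, vector dim, scalars/vector, AddExtraN
        1, 1, S<1, 32, 1, 8>, 8,                // CShuffle M/N repeats per shuffle, cluster lengths, scalars/vector
        Intrawave, BlockGemmPipelineVersion::v3, // block-GEMM pipeline scheduler and version
        half_t, half_t, false, false>;          // ComputeTypeA, ComputeTypeB, PermuteA, PermuteB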
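Once add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances is linked into the instance library, client code would normally reach these kernels through the generic DeviceOperationInstanceFactory rather than by naming an instance type. A minimal sketch, assuming the b-scale interface is DeviceGemmV2BScale with the template-parameter order inferred from the table (layouts, data types, scale-block N/K, elementwise ops); the name and parameter order are assumptions to be checked against the interface header:

    #include <iostream>
    // (plus the ck gemm and instance-factory headers already used by this file set)

    void list_b_scale_wmma_instances()
    {
        // Hedged sketch: DeviceGemmV2BScale's parameter list below is inferred
        // from the instance table, not taken from the interface header.
        using DeviceOp = ck::tensor_operation::device::DeviceGemmV2BScale<
            Row, Col, Row, F16, I4, F16, F16,
            /*ScaleBlockN=*/1, /*ScaleBlockK=*/128,
            PassThrough, PassThrough, PassThrough>;

        // GetInstances() collects every instance registered for this description.
        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        for(const auto& op : op_ptrs)
            std::cout << op->GetTypeString() << '\n'; // one line per registered instance
    }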
#include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" @@ -46,7 +46,7 @@ using device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< //#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - + //Compute friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1 diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp index 40bacb3ee9..97357f1ee4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp @@ -46,10 +46,11 @@ using device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_instances = std::tuple< //#####################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#####################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
- // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
- // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
- // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
+
+ DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
+ DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
+ DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
+ DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
 DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
 DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
@@ -65,6 +66,14 @@ using device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_instances = std::tuple<
 DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
 DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
 DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
+
+ // New instances for testing
+ DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
+ DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 8, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
+
+ DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 4, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
+ // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
+ std::nullptr_t // placeholder terminator: absorbs the trailing comma above so individual instances can be commented in and out freely
 // clang-format on
 >;
diff --git
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp index 430daae3ab..06d6780227 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp index 9b876f5430..fd938f502f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -51,8 +50,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 65261235b6..87300fa871 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -54,8 +53,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp index dc770d8d9a..902e349492 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -57,8 +56,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index 266e6b1a5d..a439cf27f5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index 1674b2de6c..55e0362018 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -51,8 +50,8 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index 
758420ca37..e51de0556c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -54,8 +53,8 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp index dad402dff4..722a0bae55 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -57,8 +56,8 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp index ee15dfa94e..d10b9facd5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp @@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -50,8 +49,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp index 93039a5008..d9d16ede65 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp @@ -34,9 +34,8 @@ 
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -51,8 +50,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp index 1dc9678c5b..9277e5e901 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp @@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using 
device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp
index e4682c27d3..e97a649c19 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp
index 0c601b3823..c8f1b85ddb 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -51,8 +50,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp
index 8d11b6f9d9..fc0220a502 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp
index d389da5ee8..b87cf64b0f 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp
index 001330eabb..31ad66409e 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -52,8 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp
index 59154f3439..a6b6465128 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp
@@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -53,8 +54,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tu
         DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
 using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
index b962d75b12..e0bbe7dff0 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
@@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -55,8 +56,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu
         DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
index 9f142ad831..5cb767ab0f 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
@@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -48,8 +49,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tu
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
index 7d141a47e1..ac29d1ba9c 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
@@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -48,8 +49,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tu
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp
index 8d109d1346..1a8227279d 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp
@@ -53,9 +53,8 @@ using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_comp_instances = std::tupl
 #endif
 
 template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
-using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| ACompType| BCompType| APermute| BPermute|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| | | | |
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| | | | |
@@ -79,8 +78,8 @@ using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances =
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp
index 940da94e70..a160f84175 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp
@@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -51,8 +52,8 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
index d83014d5e8..2f043cef03 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
@@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -63,8 +64,8 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
index ff13de1d6a..0d72da9e6e 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
@@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -48,8 +49,8 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple<
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp
index bb10da37f4..c763b5048c 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp
@@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -45,8 +46,8 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple<
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp
index 680788d668..63300d2c37 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp
@@ -53,8 +53,9 @@ using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_comp_instances = std::tuple<
 #endif
 
 template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
-using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| ACompType| BCompType| APermute| BPermute|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| | | | |
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| | | | |
@@ -78,8 +79,8 @@ using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp
index 5c525244e1..783606ef9d 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp
@@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -50,8 +51,8 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple<
         // We prefer following instance, however, existing compiler bug cause it failed to generate sanity code.
         // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
 using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp
index af4008c91d..bece6b4c30 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp
@@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
 
 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -47,8 +48,8 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple<
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
 using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
index 27d7933477..da4307d9be 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
@@ -54,6 +54,54 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple<
 #endif
 // clang-format on
 >;
+
+// instances for double rate mfma on gfx950
+template <GemmSpecialization GemmSpec>
+using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_dr = std::tuple<
+// clang-format off
+    //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
+    //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
+    //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
+    //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
+        // Compute friendly
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 32, 32, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 32, 32, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 32, 32, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 32, 32, 32, 32, 2, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 32, 32, 32, 32, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 128, 32, 32, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 256, 32, 32, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 256, 32, 32, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 256, 32, 32, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 32, 32, 32, 32, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 32, 32, 32, 32, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 32, 32, 32, 32, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 32, 32, 32, 32, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 32, 32, 32, 32, 2, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 32, 32, 32, 32, 2, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 32, 32, 32, 32, 2, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 32, 32, 32, 32, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 32, 32, 32, 32, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 32, 32, 32, 32, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 32, 32, 32, 32, 1, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 32, 32, 32, 32, 1, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 32, 32, 32, 32, 1, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
+        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16,
PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 64, 128, 32, 32, 32, 32, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 64, 128, 32, 32, 32, 32, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 64, 128, 32, 32, 32, 32, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 256, 32, 32, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 256, 32, 32, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 256, 32, 32, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 512, 32, 32, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 512, 32, 32, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 512, 32, 32, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8> +#endif + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< @@ -115,6 +163,42 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple< #endif // clang-format on >; +// instances for double rate mfma on gfx950 +template +using 
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances_dr = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Latency friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 256, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 32, 32, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 512, 32, 32, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + // Memory friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 256, 32, 32, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 
0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 256, 32, 32, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 256, 32, 32, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 256, 32, 32, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 256, 32, 32, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 256, 32, 32, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 256, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 32, 32, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 512, 32, 32, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, 
PassThrough, GemmSpec, 128, 16, 32, 512, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 256, 32, 32, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 256, 32, 32, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 256, 32, 32, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 256, 32, 32, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 256, 32, 32, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8> +#endif + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp index d6c9809020..6cf0228c04 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp @@ -17,7 +17,13 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances{}); - if(ck::get_device_name() != "gfx950") + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_dr{}); + } + else { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp 
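
The comp_default and comp_kpadding hunks here flip the gfx950 handling for the f8/f8/bf16 compute instances: before, gfx950 simply skipped the part2 tuple; now gfx950 registers the dedicated double-rate (`_dr`) tuple while every other target still gets part2. A condensed sketch of the resulting registration flow, assuming the helper and tuple names from the hunks; the enclosing function name and the `GemmDefault` specialization below are placeholders, and the CK instance headers that define these helpers are elided:

```cpp
// Condensed sketch of the per-device registration introduced above.
// add_device_operation_instances() and ck::get_device_name() appear in the
// hunks; add_comp_instances_sketch() and GemmDefault are placeholders.
template <typename DeviceOpPtr>
void add_comp_instances_sketch(std::vector<DeviceOpPtr>& instances)
{
    // Tuple valid on all supported targets.
    add_device_operation_instances(
        instances,
        device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances<GemmDefault>{});

    if(ck::get_device_name() == "gfx950")
    {
        // gfx950 path: register the double-rate MFMA tuple defined above.
        add_device_operation_instances(
            instances,
            device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_dr<GemmDefault>{});
    }
    else
    {
        // All other targets keep the tuple that does not work on gfx950.
        add_device_operation_instances(
            instances,
            device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_part2<GemmDefault>{});
    }
}
```

The mem_v1/mem_v2 hunks that follow apply the same device check but append the `_dr` tuple with no else branch, since those files have no gfx950-incompatible split to fall back to.
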
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp index fc6ad01742..65e49d5f88 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp @@ -17,7 +17,13 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances{}); - if(ck::get_device_name() != "gfx950") + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_dr{}); + } + else { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp index f6a9c48555..56c7c71a13 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp @@ -16,6 +16,14 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances_dr{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp index f9c12e7cb2..bad30bad99 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -16,6 +16,14 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances_dr{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp index 1d33c7fa57..8d6b8dcbca 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp @@ -16,6 +16,14 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances_dr{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp index 252aec5bc2..d0bbc4aeda 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -16,6 +16,14 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances_dr{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp index b4554fc6a9..f03dc4fc8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -54,8 +55,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp index b6a60a1f31..7f1976f220 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| 
CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -59,8 +60,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp index 5353fe16b5..93ac0d7dcc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using 
device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -51,8 +52,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp index 959c1c0992..b2e3252e4d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -61,8 +62,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 16, 1, 16>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp index 282cea7563..a318627bea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using 
device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -54,8 +55,8 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp index 7335a9851f..92e5c86343 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = 
BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -62,8 +63,8 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp index d03002af5c..f83b0a47c9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp +++ 
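
For context on the streamk_* files being reflowed in these hunks: `DeviceGemm_Xdl_CShuffle_Streamk_V3` is the stream-K flavor of the universal kernel. Rather than assigning one workgroup per C tile, stream-K hands every compute unit an equal slice of the global K-loop iterations and stitches split tiles back together with a fix-up reduction, which keeps the device busy when there are too few tiles to fill it. A toy partitioning sketch of that idea, illustrative only and not CK's implementation:

```cpp
// Toy illustration of stream-K work partitioning: every CU gets an equal
// share of the global MAC iterations; tiles whose K-range straddles a CU
// boundary would be completed by a fix-up reduction pass.
#include <cstdio>

int main()
{
    const int num_tiles = 7, iters_per_tile = 32, num_cus = 4;
    const int total = num_tiles * iters_per_tile;
    for(int cu = 0; cu < num_cus; ++cu)
    {
        const int begin = static_cast<long long>(total) * cu / num_cus;
        const int end   = static_cast<long long>(total) * (cu + 1) / num_cus;
        std::printf("CU %d: iters [%d, %d) covering tiles %d..%d\n",
                    cu, begin, end, begin / iters_per_tile, (end - 1) / iters_per_tile);
    }
    return 0;
}
```
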
b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -51,8 +52,8 @@ using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp index 7736f38cb2..2de3ed35b0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -49,8 +50,8 @@ using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp index 57b6ab3ae2..a38eef7294 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -52,8 +53,8 @@ using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_instances = std // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, 
BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp index 14bd36d29f..d2e15f01da 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -49,8 +50,8 @@ using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp index 3f94d30a55..320d637a07 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp @@ -32,6 +32,14 @@ void 
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances( Empty_Tuple, NGCHW, ConvBwdDataDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_nchw_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp index bada2507c2..b1043260ea 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -32,6 +32,14 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( Empty_Tuple, NGCDHW, ConvBwdDataDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_nchw_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp index 839d3559f7..2344108576 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp @@ -80,9 +80,8 @@ template -using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = - std::tuple< - // clang-format off +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| @@ -99,8 +98,8 @@ using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = // 
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/src/utility/convolution_parameter.cpp b/library/src/utility/convolution_parameter.cpp index a71f8a4fa1..634b7f0890 100644 --- a/library/src/utility/convolution_parameter.cpp +++ b/library/src/utility/convolution_parameter.cpp @@ -215,9 +215,8 @@ ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, ch std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p) { - os << "ConvParam {" - << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nG: " << p.G_ << "\nN: " << p.N_ - << "\nK: " << p.K_ << "\nC: " << p.C_ + os << "ConvParam {" << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nG: " << p.G_ + << "\nN: " << p.N_ << "\nK: " << p.K_ << "\nC: " << p.C_ << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_ << "\ninput_spatial_lengths: " << p.input_spatial_lengths_ << "\nconv_filter_strides: " << p.conv_filter_strides_ diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp index b70dd9538d..5ea1a78094 100644 --- a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -260,9 +260,9 @@ bool profile_conv_bwd_data_impl(int do_verification, } } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_conv_fwd_impl.hpp b/profiler/include/profiler/profile_conv_fwd_impl.hpp index 917e4c07fc..37366821c4 100644 --- a/profiler/include/profiler/profile_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_impl.hpp @@ -233,9 +233,9 @@ bool profile_conv_fwd_impl(int do_verification, } } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + 
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp b/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp index fa0a771962..14182bb7b0 100644 --- a/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp +++ b/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp @@ -288,9 +288,8 @@ bool profile_conv_tensor_rearrange_impl(int do_verification, } } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\nGB/s: " << best_gb_per_sec << std::endl; return is_supporting_instance && pass; } diff --git a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp index fe977e766e..86370e2f47 100644 --- a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -173,7 +173,7 @@ bool profile_gemm_b_scale_impl(int do_verification, } } using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm #include #include +#if defined(__unix__) #include +#endif #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -213,7 +215,9 @@ int profile_gemm_impl(int do_verification, instance_id++; } +#if defined(__unix__) sleep(2); +#endif // Run the best instance again { diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp index 12f6ad606f..0aeefaabfb 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp @@ -287,10 +287,9 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, } } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK " - << best_split_k << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index c1bb90dd9c..84acb53425 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -92,12 +92,12 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, if(do_verification) { auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight{}; + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp>{}; auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(input, weight_host_result, @@ -302,10 +302,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, } } - std::cout << "Best 
configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK " - << best_split_k << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl; return all_pass; } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index c12fa75e34..d0e1cf2611 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -178,8 +178,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, in_element_op, wei_element_op, out_element_op, - {}, - {}, + {}, + {}, d_tensors); // init host output to zero @@ -312,9 +312,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, run_impl(op_ptr, argument_ptr); } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index a1f9ee1528..2dcee4c1fc 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -250,9 +250,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification, run_impl(op_ptr, argument_ptr); } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index bd756eb825..b553e07735 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -342,9 +342,9 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, run_impl(op_ptr, argument_ptr); } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_softmax_impl.hpp b/profiler/include/profiler/profile_softmax_impl.hpp index daaf565149..83913d8398 100644 --- a/profiler/include/profiler/profile_softmax_impl.hpp +++ b/profiler/include/profiler/profile_softmax_impl.hpp @@ -103,12 +103,12 @@ bool profile_softmax_impl(int 
do_verification, // add device softmax instances using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceOp = tensor_operation::device::DeviceSoftmax; + AccDataType, + OutDataType, + PassThrough, + PassThrough, + Rank, + NumReduceDim>; // get device op instances const auto instances = tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -141,8 +141,7 @@ bool profile_softmax_impl(int do_verification, { std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; LogRange(std::cout << "input lengths = [", in_length, ", ") - << "], " - << "scaler = [" << alpha << ", " << beta << "]"; + << "], " << "scaler = [" << alpha << ", " << beta << "]"; LogRange(std::cout << ", reduce dims = [", reduce_dims, ", ") << "]." << std::endl; instance_pass.push_back(true); continue; @@ -202,8 +201,7 @@ bool profile_softmax_impl(int do_verification, { std::cout << inst_ptr->GetTypeString() << " failed verification: "; LogRange(std::cout << "input lengths = [", in_length, ", ") - << "], " - << "scaler = [" << alpha << ", " << beta << "]." << std::endl; + << "], " << "scaler = [" << alpha << ", " << beta << "]." << std::endl; } instance_pass.push_back(pass); } @@ -215,9 +213,8 @@ bool profile_softmax_impl(int do_verification, LogRange(std::cout << "length = ", in_tensor_lengths, ",") << ", "; LogRange(std::cout << "stride = ", in_tensor_strides, ",") << ", "; LogRange(std::cout << "reduce dims ", reduce_dims, ",") << ", "; - std::cout << "alpha = " << alpha << ", " - << "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec - << " GB/s, " << best_instance_name << std::endl; + std::cout << "alpha = " << alpha << ", " << "beta = " << beta << ", " << best_avg_time + << " ms, " << best_gb_per_sec << " GB/s, " << best_instance_name << std::endl; } return std::all_of( std::begin(instance_pass), std::end(instance_pass), [](bool p) { return p; }); diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 1dc942699f..4700a34e9d 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -72,7 +72,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_splitk.cpp) - list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) @@ -93,7 +92,10 @@ endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_gemm_universal.cpp) list(APPEND PROFILER_OPS profile_batched_gemm.cpp) + list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) endif() @@ -178,7 +180,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_gemm_mx_instance) endif() list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance) - list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES 
device_gemm_universal_batched_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) @@ -198,6 +199,10 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance) endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR @@ -208,6 +213,7 @@ endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") list(APPEND DEVICE_INSTANCES device_gemm_universal_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) + list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) diff --git a/profiler/src/profile_batched_gemm_b_scale.cpp b/profiler/src/profile_batched_gemm_b_scale.cpp index f768a17570..5fe6f490be 100644 --- a/profiler/src/profile_batched_gemm_b_scale.cpp +++ b/profiler/src/profile_batched_gemm_b_scale.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "profiler/profile_batched_gemm_b_scale_impl.hpp" #include "profiler_operation_registry.hpp" @@ -114,7 +115,7 @@ int profile_batched_gemm_b_scale(int argc, char* argv[]) n_iter = std::stoi(argv[18]); rotating = std::stoull(argv[19]) * 1024 * 1024; - printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating); + printf("n_warmup:%d, n_iter:%d, rotating:%" PRIu64 "\n", n_warmup, n_iter, rotating); } using F32 = float; diff --git a/profiler/src/profile_contraction_bilinear.cpp b/profiler/src/profile_contraction_bilinear.cpp index 990e1e1196..a64555fc66 100644 --- a/profiler/src/profile_contraction_bilinear.cpp +++ b/profiler/src/profile_contraction_bilinear.cpp @@ -29,8 +29,7 @@ static void print_helper_msg() << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" << "arg6: verification (0: no; 1: yes)\n" - << "arg7: initialization (0: no init; 1: integer value; 2: decimal " - << "value)\n" + << "arg7: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n" << "arg8: print tensor value (0: no; 1: yes)\n" << "arg9: time kernel (0: no, 1: yes)\n" << "arg10: alpha\n" diff --git a/profiler/src/profile_contraction_scale.cpp b/profiler/src/profile_contraction_scale.cpp index 85252eaa37..a168c09bcf 100644 --- a/profiler/src/profile_contraction_scale.cpp +++ b/profiler/src/profile_contraction_scale.cpp @@ -29,8 +29,7 @@ static void print_helper_msg() << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" << "arg6: verification (0: no; 1: yes)\n" - << "arg7: initialization (0: no init; 1: integer value; 2: decimal " - << "value)\n" + << "arg7: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n" << "arg8: print tensor value (0: no; 1: yes)\n" << "arg9: time kernel (0: no, 1: yes)\n" << "arg10: alpha\n" diff --git a/profiler/src/profile_gemm_b_scale.cpp b/profiler/src/profile_gemm_b_scale.cpp index 
443ebff834..7bcc96a434 100644 --- a/profiler/src/profile_gemm_b_scale.cpp +++ b/profiler/src/profile_gemm_b_scale.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "profiler/profile_gemm_b_scale_impl.hpp" #include "profiler_operation_registry.hpp" @@ -100,7 +101,7 @@ int profile_gemm_b_scale(int argc, char* argv[]) n_iter = std::stoi(argv[17]); rotating = std::stoull(argv[18]) * 1024 * 1024; - printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating); + printf("n_warmup:%d, n_iter:%d, rotating:%" PRIu64 "\n", n_warmup, n_iter, rotating); } using F32 = float; diff --git a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp new file mode 100644 index 0000000000..34b3df1c65 --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/ignore.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_bias_clamp" +#define OP_DESC "Grouped Convolution Forward+Bias+Clamp" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = 
std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_bias_clamp); diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp new file mode 100644 index 0000000000..600f91744a --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
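// [Editorial sketch, not part of the patch] Both new profiler entry points in this part
// of the patch (grouped_conv_fwd_bias_clamp above, grouped_conv_fwd_clamp below) gate
// parsing on an exact positional-argument count; the helper name below is illustrative,
// not from the patch.
constexpr int expected_argc(int num_dim_spatial)
{
    // 9 control args (argv[0..8]) + 1 for num_dim_spatial + 4 for G/N/K/C
    // + 6 per spatial dim: filter size, input size, stride, dilation, left pad, right pad
    return 9 + 1 + 4 + 6 * num_dim_spatial;
}
static_assert(expected_argc(2) == 26, "2D grouped conv: 26 positional args");
static_assert(expected_argc(3) == 32, "3D grouped conv: 32 positional args");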
+ +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/ignore.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_clamp" +#define OP_DESC "Grouped Convolution Forward+Clamp" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_clamp(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto 
wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = + ck::profiler::profile_grouped_conv_fwd_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_clamp); diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index 728b8c1092..53de05a7d8 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' -git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' +find . 
-name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}'
+git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}'
diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh
index 4d0836af39..c45bb4330d 100755
--- a/script/cmake-ck-dev.sh
+++ b/script/cmake-ck-dev.sh
@@ -5,19 +5,33 @@ rm -rf CMakeFiles
 MY_PROJECT_SOURCE=$1
-if [ $# -ge 2 ] ; then
-    GPU_TARGETS=$2
-    shift 2
-    REST_ARGS=$@
+
+if [ $# -ge 2 ]; then
+    case "$2" in
+    gfx*)
+        GPU_TARGETS=$2
+        shift 2
+        echo "GPU targets provided: $GPU_TARGETS"
+        REST_ARGS=$@
+        ;;
+    *)
+        echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942"
+        GPU_TARGETS="gfx908;gfx90a;gfx942"
+        shift 1
+        REST_ARGS=$@
+        ;;
+    esac
 else
+    echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942"
     GPU_TARGETS="gfx908;gfx90a;gfx942"
-    REST_ARGS=
+    shift 1
+    REST_ARGS=$@
 fi
 cmake \
 -D CMAKE_PREFIX_PATH=/opt/rocm/ \
 -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
--D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
+-D CMAKE_CXX_FLAGS="-std=c++20 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
 -D CMAKE_BUILD_TYPE=Release \
 -D BUILD_DEV=ON \
 -D GPU_TARGETS=$GPU_TARGETS \
diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh
index acb04ac75f..311ea91822 100755
--- a/script/cmake-ck-release.sh
+++ b/script/cmake-ck-release.sh
@@ -5,13 +5,16 @@ rm -rf CMakeFiles
 MY_PROJECT_SOURCE=$1
-if [ $# -ge 2 ] ; then
+if [ $# -ge 2 ] && [[ "$2" =~ ^gfx ]]; then
     GPU_TARGETS=$2
     shift 2
+    echo "GPU targets provided: $GPU_TARGETS"
     REST_ARGS=$@
 else
+    echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942"
     GPU_TARGETS="gfx908;gfx90a;gfx942"
-    REST_ARGS=
+    shift 1
+    REST_ARGS=$@
 fi
 cmake \
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 1be7c88c2e..c6c09eb6ca 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -37,6 +37,9 @@ set(REGRESSION_TESTS
     test_grouped_convnd_bwd_data_xdl
     test_conv_tensor_rearrange
     test_gemm_mx
+    test_ck_tile_batched_transpose_fp8
+    test_ck_tile_batched_transpose_fp16
+    test_ck_tile_batched_transpose_bf16
 )
 function(add_test_executable TEST_NAME)
@@ -239,6 +242,7 @@
 add_subdirectory(gemm_add)
 add_subdirectory(gemm_layernorm)
 add_subdirectory(gemm_split_k)
 add_subdirectory(gemm_universal)
+add_subdirectory(gemm_b_scale)
 add_subdirectory(gemm_universal_streamk)
 add_subdirectory(gemm_reduce)
 add_subdirectory(batched_gemm)
diff --git a/test/block_swizzle_test/rebuild.sh b/test/block_swizzle_test/rebuild.sh
index b07eb55048..553d1900d4 100644
--- a/test/block_swizzle_test/rebuild.sh
+++ b/test/block_swizzle_test/rebuild.sh
@@ -1,3 +1,3 @@
 CC=g++
-$CC -Wall -std=c++17 -Iinclude -O3 block_swizzle_test.cpp -o block_swizzle_test.exe
\ No newline at end of file
+$CC -Wall -std=c++20 -Iinclude -O3 block_swizzle_test.cpp -o block_swizzle_test.exe
\ No newline at end of file
diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt
index cc933012ac..42605f2513 100644
--- a/test/ck_tile/CMakeLists.txt
+++ b/test/ck_tile/CMakeLists.txt
@@ -5,4 +5,19 @@
 add_subdirectory(batched_gemm)
 add_subdirectory(grouped_gemm)
 add_subdirectory(gemm_multi_d)
 add_subdirectory(data_type)
+add_subdirectory(container) +add_subdirectory(elementwise) +# Not including these tests as there is a bug on gfx90a and gfx942 +# resulting in "GPU core dump" +#add_subdirectory(moe_smoothquant) +add_subdirectory(permute) +add_subdirectory(moe_sorting) add_subdirectory(slice_tile) +add_subdirectory(memory_copy) +add_subdirectory(batched_transpose) +add_subdirectory(smoothquant) +add_subdirectory(topk_softmax) +add_subdirectory(add_rmsnorm2d_rdquant) +# add_subdirectory(layernorm2d) +# add_subdirectory(rmsnorm2d) +add_subdirectory(gemm_block_scale) diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt b/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt new file mode 100644 index 0000000000..37774f7643 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt @@ -0,0 +1,26 @@ +function(create_tile_add_rmsnorm2d_rdquant_fwd SUFFIX) + set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "test_ck_tile_add_rmsnorm2d_rdquant_fwd_${SUFFIX}") + message(DEBUG "adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}") + file(GLOB INSTANCE_SRCS instances/*.cpp) + add_test_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} add_rmsnorm2d_rdquant_fwd_${SUFFIX}.cpp) + target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS}) + + set(TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS) + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) + + # TODO: we have to turn off this global prop, otherwise the progress bar generated + # by cmake will print too many files, execvp: /bin/sh: Argument list too long + # however, this property may affect global + # TODO: consider codegen a makefile by us + set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) +endfunction() + +if(GPU_TARGETS MATCHES "gfx9") + create_tile_add_rmsnorm2d_rdquant_fwd("fp16") + create_tile_add_rmsnorm2d_rdquant_fwd("bf16") +else() + message(DEBUG "Skipping ck tile add_rmsnorm2d_rdquant_fwd tests for current target") +endif() diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp new file mode 100644 index 0000000000..faa134e5c4 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
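// [Editorial sketch, not part of the patch] The traits header that follows derives its
// warp layout from the thread-block shape. Assuming a 64-lane wavefront (true for the
// gfx9 targets this test is gated to); the helper names here are illustrative only:
constexpr int kWarpSize = 64;
constexpr int warps_along_m(int tm, int tn)
{
    // a warp covers one or more whole rows when a row fits inside a warp (tn <= 64),
    // otherwise several warps share one row
    return tn <= kWarpSize ? ((tm * tn) / kWarpSize) * (kWarpSize / tn)
                           : ((tm * tn) / kWarpSize) / (tn / kWarpSize);
}
constexpr int warps_along_n(int tn) { return tn <= kWarpSize ? 1 : tn / kWarpSize; }
// e.g. a 4x64 thread block is four row-warps; a 2x128 block splits each row across 2 warps
static_assert(warps_along_m(4, 64) == 4 && warps_along_n(64) == 1, "");
static_assert(warps_along_m(2, 128) == 2 && warps_along_n(128) == 2, "");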
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp"
+#include <string>
+
+template <typename InputDataType, typename QuantizedDataType>
+struct AddRmsnormRdquantTypeConfig;
+
+template <>
+struct AddRmsnormRdquantTypeConfig<ck_tile::half_t, ck_tile::int8_t>
+{
+    using ADataType       = ck_tile::half_t;
+    using BDataType       = ck_tile::half_t;
+    using GammaDataType   = ck_tile::half_t;
+    using XDataType       = ck_tile::half_t;
+    using YScaleDataType  = float;
+    using QYDataType      = ck_tile::int8_t;
+    using ComputeDataType = float;
+};
+
+template <>
+struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t, ck_tile::int8_t>
+{
+    using ADataType       = ck_tile::bf16_t;
+    using BDataType       = ck_tile::bf16_t;
+    using GammaDataType   = ck_tile::bf16_t;
+    using XDataType       = ck_tile::bf16_t;
+    using YScaleDataType  = float;
+    using QYDataType      = ck_tile::int8_t;
+    using ComputeDataType = float;
+};
+
+template <>
+struct AddRmsnormRdquantTypeConfig<ck_tile::half_t, ck_tile::fp8_t>
+{
+    using ADataType       = ck_tile::half_t;
+    using BDataType       = ck_tile::half_t;
+    using GammaDataType   = ck_tile::half_t;
+    using XDataType       = ck_tile::half_t;
+    using YScaleDataType  = float;
+    using QYDataType      = ck_tile::fp8_t;
+    using ComputeDataType = float;
+};
+
+template <>
+struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t, ck_tile::fp8_t>
+{
+    using ADataType       = ck_tile::bf16_t;
+    using BDataType       = ck_tile::bf16_t;
+    using GammaDataType   = ck_tile::bf16_t;
+    using XDataType       = ck_tile::bf16_t;
+    using YScaleDataType  = float;
+    using QYDataType      = ck_tile::fp8_t;
+    using ComputeDataType = float;
+};
+
+// runtime args
+struct add_rmsnorm2d_rdquant_fwd_args : public ck_tile::AddRmsnorm2dRdquantFwdHostArgs
+{
+};
+
+// This is used to pattern-match the internal kernel implementation, not to instantiate the kernel.
+template <typename InputDataType_,
+          typename QuantizedDataType_,
+          ck_tile::index_t Repeat_M_,         // each thread repeats along M
+          ck_tile::index_t Repeat_N_,         // each thread repeats along N
+          ck_tile::index_t ThreadPerBlock_M_, // threads along M
+          ck_tile::index_t ThreadPerBlock_N_, // threads along N
+          ck_tile::index_t Vector_N_,         // vector access size along N
+          bool kPadN_,
+          bool kSaveX_,
+          bool kThreePass_>
+struct add_rmsnorm2d_rdquant_fwd_traits_
+{
+    using InputDataType     = ck_tile::remove_cvref_t<InputDataType_>;
+    using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
+
+    static constexpr auto WarpSize = ck_tile::get_warp_size();
+    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
+    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
+    static constexpr ck_tile::index_t total_warps =
+        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
+
+    // num of warps along m
+    static constexpr ck_tile::index_t BlockWarps_M = []() {
+        if constexpr(is_warp_per_row)
+        {
+            static_assert(WarpSize % ThreadPerBlock_N_ == 0);
+            return total_warps * (WarpSize / ThreadPerBlock_N_);
+        }
+        else
+        {
+            // static_assert(WarpSize % ThreadPerBlock_M_ == 0);
+            return total_warps / (ThreadPerBlock_N_ / WarpSize);
+        }
+    }();
+
+    // num of warps along n
+    static constexpr ck_tile::index_t BlockWarps_N = []() {
+        if constexpr(is_warp_per_row)
+        {
+            static_assert(WarpSize % ThreadPerBlock_N_ == 0);
+            return 1;
+        }
+        else
+        {
+            static_assert(ThreadPerBlock_N_ % WarpSize == 0);
+            return ThreadPerBlock_N_ / WarpSize;
+        }
+    }();
+
+    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
+    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
+
+    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
+    static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
+
+    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
+    static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
+
+    using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
+    using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
+    using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
+    using Vector     = ck_tile::sequence<1, Vector_N_>;
+
+    using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
+
+    static constexpr bool kPadN  = kPadN_;
+    static constexpr bool kSaveX = kSaveX_;
+ static constexpr bool kThreePass = kThreePass_; +}; + +template +float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_rdquant_fwd_args a); + +// This is the public API, will be generated by script +struct add_rmsnorm2d_rdquant_fwd_traits +{ + std::string input_data_type; + std::string quantized_data_type; + bool save_x; +}; + +float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits, + add_rmsnorm2d_rdquant_fwd_args, + const ck_tile::stream_config&); diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc new file mode 100644 index 0000000000..116d3798b9 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc @@ -0,0 +1,370 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/host.hpp" +#include "add_rmsnorm2d_rdquant_fwd.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("save_x", "1", "save rms(invrms) or not. 
set to 1 in training case") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("quant", "int8", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string input_data_type = arg_parser.get_str("prec"); + std::string quantized_data_type = arg_parser.get_str("quant"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using TypeConfig = AddRmsnormRdquantTypeConfig; + + using ADataType = typename TypeConfig::ADataType; + using BDataType = typename TypeConfig::BDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using XDataType = typename TypeConfig::XDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = float; + using UnquantYDataType = ck_tile::null_type; + + // host verify + ck_tile::HostTensor a_host({m, n}, {stride, 1}); + ck_tile::HostTensor b_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor x_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor x_host_dev({m, n}, {stride, 1}); + + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.data()); + b_buf.ToDevice(b_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m + << ", n:" << n << ", stride:" << stride << std::flush; + + add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX}; + + add_rmsnorm2d_rdquant_fwd_args args{a_buf.GetDeviceBuffer(), + b_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + x_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + epsilon, + m, + n, + stride}; + + float ave_time = add_rmsnorm2d_rdquant_fwd( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 
1 : 0, warmup, repeat});
+
+    std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(BDataType) * m * n +
+                           sizeof(GammaDataType) * n + sizeof(YScaleDataType) * m +
+                           sizeof(QYDataType) * m * n;
+
+    if constexpr(SaveX)
+        num_byte += sizeof(XDataType) * m * n;
+
+    float gb_per_sec = num_byte / 1.E6 / ave_time;
+    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::endl;
+
+    bool pass = true;
+
+    if(do_validation)
+    {
+        using YDataType      = ComputeDataType;
+        using InvRmsDataType = InputDataType;
+
+        // Add
+        {
+            auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
+            ck_tile::reference_binary_elementwise(
+                a_host, b_host, x_host_ref, op);
+
+            if constexpr(SaveX)
+            {
+                x_buf.FromDevice(x_host_dev.data());
+
+                auto [rtol, atol] = get_elimit();
+                if(stride == n)
+                {
+                    pass = ck_tile::check_err(x_host_dev,
+                                              x_host_ref,
+                                              std::string("x Error: Incorrect results!"),
+                                              rtol,
+                                              atol);
+                }
+                else
+                {
+                    for(int i_r = 0; i_r < m; i_r++)
+                    {
+                        std::vector<XDataType> x_host_dev_row(x_host_dev.begin() + i_r * stride,
+                                                              x_host_dev.begin() + i_r * stride +
+                                                                  n);
+                        std::vector<XDataType> x_host_ref_row(x_host_ref.begin() + i_r * stride,
+                                                              x_host_ref.begin() + i_r * stride +
+                                                                  n);
+                        pass &= ck_tile::check_err(x_host_dev_row,
+                                                   x_host_ref_row,
+                                                   std::string("x[") + std::to_string(i_r) +
+                                                       std::string("] Error: Incorrect results!"),
+                                                   rtol,
+                                                   atol);
+                    }
+                }
+            }
+        }
+
+        ck_tile::HostTensor<YDataType> y_host({m, n});
+        // Rmsnorm2d
+        {
+            ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
+            ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n});
+
+            // CAUTION: the kernel uses the ComputeDataType version of x, but we use XDataType
+            // here for simplicity
+            ck_tile::reference_rmsnorm2d_fwd(
+                x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
+        }
+
+        // yscale
+        {
+            ck_tile::HostTensor<ComputeDataType> y_rowwise_amax_host({m});
+
+            using ReduceAmax = ck_tile::ReduceOp::AbsMax;
+            ck_tile::reference_reduce(
+                y_host, y_rowwise_amax_host, ReduceAmax{});
+
+            auto op = [](const auto& v0) {
+                return v0 /
+                       ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
+            };
+            ck_tile::reference_unary_elementwise(
+                y_rowwise_amax_host, yscale_host_ref, op);
+
+            yscale_buf.FromDevice(yscale_host_dev.mData.data());
+
+            auto [rtol, atol] = get_elimit();
+            pass &= ck_tile::check_err(yscale_host_dev,
+                                       yscale_host_ref,
+                                       std::string("yscale Error: Incorrect results!"),
+                                       rtol,
+                                       atol);
+        }
+
+        // rowwise quantization
+        {
+            ck_tile::reference_rowwise_quantization2d(
+                y_host, yscale_host_ref, qy_host_ref);
+
+            qy_buf.FromDevice(qy_host_dev.data());
+            auto [rtol, atol] = get_elimit();
+
+            if(stride == n)
+            {
+                pass = ck_tile::check_err(qy_host_dev,
+                                          qy_host_ref,
+                                          std::string("qy Error: Incorrect results!"),
+                                          rtol,
+                                          atol);
+            }
+            else
+            {
+                for(int i_r = 0; i_r < m; i_r++)
+                {
+                    std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
+                                                            qy_host_dev.begin() + i_r * stride + n);
+                    std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
+                                                            qy_host_ref.begin() + i_r * stride + n);
+                    pass &= ck_tile::check_err(qy_host_dev_row,
+                                               qy_host_ref_row,
+                                               std::string("qy[") + std::to_string(i_r) +
+                                                   std::string("] Error: Incorrect results!"),
+                                               rtol,
+                                               atol);
+                }
+            }
+        }
+
+        std::cout << ", valid:" << (pass ?
"y" : "n") << std::flush << std::endl; + } + + return pass; +} + +bool dispatch_by_type(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return false; + + const std::string input_data_type = arg_parser.get_str("prec"); + const std::string quantized_data_type = arg_parser.get_str("quant"); + int save_x = arg_parser.get_int("save_x"); + if(input_data_type == "fp16" && quantized_data_type == "int8" && save_x) + { + return run(arg_parser); + } + else if(input_data_type == "fp16" && quantized_data_type == "int8" && !save_x) + { + return run(arg_parser); + } + else if(input_data_type == "bf16" && quantized_data_type == "int8" && save_x) + { + return run(arg_parser); + } + else if(input_data_type == "bf16" && quantized_data_type == "int8" && !save_x) + { + return run(arg_parser); + } + else if(input_data_type == "fp16" && quantized_data_type == "fp8" && save_x) + { + return run(arg_parser); + } + else if(input_data_type == "fp16" && quantized_data_type == "fp8" && !save_x) + { + return run(arg_parser); + } + else if(input_data_type == "bf16" && quantized_data_type == "fp8" && save_x) + { + return run(arg_parser); + } + else if(input_data_type == "bf16" && quantized_data_type == "fp8" && !save_x) + { + return run(arg_parser); + } + + return false; +} + +int run_add_rmsnorm2d_rdquant_combinations(std::string const& data_type) +{ + constexpr size_t PARAM_COUNT = 11; + char bufs[PARAM_COUNT][64]; + char* argv[PARAM_COUNT]; + + for(std::size_t i = 0; i < PARAM_COUNT; i++) + { + argv[i] = bufs[i]; + } + + std::vector> params = { + {"-m=99", "-n=13"}, + {"-m=17", "-n=16"}, + {"-m=1", "-n=100"}, + {"-m=4", "-n=128"}, + {"-m=80", "-n=127"}, + {"-m=22", "-n=255", "-stride=256"}, + {"-m=7", "-n=599"}, + {"-m=19", "-n=512"}, + {"-m=33", "-n=313", "-stride=1000"}, + {"-m=11", "-n=510"}, + {"-m=171", "-n=676", "-stride=818"}, + {"-m=91", "-n=636"}, + {"-m=12", "-n=768", "-stride=800"}, + {"-m=100", "-n=766", "-stride=812"}, + {"-m=31", "-n=1024"}, + {"-m=64", "-n=1000", "-stride=1004"}, + {"-m=8", "-n=1501"}, + {"-m=3", "-n=1826"}, + {"-m=5", "-n=2040"}, + {"-m=7", "-n=2734"}, + {"-m=1", "-n=3182"}, + {"-m=9", "-n=4096"}, + {"-m=3", "-n=8192"}, + {"-m=1", "-n=10547"}, + {"-m=3", "-n=17134"}, + }; + + bool result = true; + std::string pr_i = "-prec=" + data_type; + strncpy(bufs[0], "add_rmsnorm2d_rdquant_fwd", 64); + strncpy(bufs[1], pr_i.c_str(), 64); + for(size_t i = 0; i < params.size(); i++) + { + for(size_t j = 0; j < params[i].size(); j++) + { + strncpy(bufs[j + 2], params[i][j].c_str(), 64); + } + int argc = params[i].size() + 2; + + result = dispatch_by_type(argc, argv) && result; + } + return result ? 0 : -1; +} diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_bf16.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_bf16.cpp new file mode 100644 index 0000000000..1e0863fa62 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_bf16.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd.inc" + +int main() { return run_add_rmsnorm2d_rdquant_combinations("bf16"); } diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_fp16.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_fp16.cpp new file mode 100644 index 0000000000..0a0a4c4f83 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_fp16.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd.inc" + +int main() { return run_add_rmsnorm2d_rdquant_combinations("fp16"); } diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp new file mode 100644 index 0000000000..f695ea30b2 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "add_rmsnorm2d_rdquant_fwd.hpp" + +template +using trait_ = add_rmsnorm2d_rdquant_fwd_traits_; + +template +float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t, + add_rmsnorm2d_rdquant_fwd_args a, + const ck_tile::stream_config& s) +{ + float r = -1; + // clang-format off + // rm rn tm tn vn pd x 3p + if(a.n <= 64) { + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, 
a);
+        else if (a.n % 2 == 0)
+            r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+        else
+            r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+    }
+    else if(a.n <= 8192) {
+        if(a.n < 8192){
+            if(t.save_x){
+                if (a.n % 8 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else if (a.n % 4 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else if (a.n % 2 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+            }
+            else{
+                if (a.n % 8 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else if (a.n % 4 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else if (a.n % 2 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+            }
+        }
+        else{ // a.n == 8192 exactly
+            if(t.save_x){
+                if (a.n % 8 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else if (a.n % 4 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else if (a.n % 2 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+            }
+            else{
+                if (a.n % 8 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else if (a.n % 4 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else if (a.n % 2 == 0)
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+                else
+                    r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+            }
+        }
+    }
+    else { // a.n > 8192
+        if (a.n % 8 == 0)
+            r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+        else if (a.n % 4 == 0)
+            r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+        else if (a.n % 2 == 0)
+            r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+        else
+            r = add_rmsnorm2d_rdquant_fwd_>(s, a);
+    }
+    return r;
+    // clang-format on
+}
+
+float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
+                                add_rmsnorm2d_rdquant_fwd_args a,
+                                const ck_tile::stream_config& s)
+{
+    if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
+       t.save_x)
+    {
+        return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s);
+    }
+    else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
+            !t.save_x)
+    {
+        return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s);
+    }
+    else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
+            t.save_x)
+    {
+        return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s);
+    }
+    else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
+            !t.save_x)
+    {
+        return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s);
+    }
+    else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
+            t.save_x)
+    {
+        return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s);
+    }
+    else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
+            !t.save_x)
+    {
+        return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s);
+    }
+    else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
+            t.save_x)
+    {
+        return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s);
+    }
+    else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
+            !t.save_x)
+    {
+        return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s);
+    }
+    else
+        throw std::runtime_error("no supported instance for the requested data types!");
+}
diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp
new file mode 100644
index 0000000000..00df2f5082
--- /dev/null
+++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp
@@ -0,0 +1,26 @@
+
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc.
All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +#if 0 +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +#endif + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp new file mode 100644 index 0000000000..2adb54c078 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp new file mode 100644 index 0000000000..39089843a2 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
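+
+// NOTE: the instantiations in these instance files back the dispatch in
+// add_rmsnorm2d_rdquant_fwd_api.cpp above, which buckets the row length a.n
+// by size and then picks the widest N-vectorization (the "vn" column) that
+// divides a.n evenly, capped at 8. A hypothetical stand-alone helper
+// (illustration only, not used by the dispatch) expressing the same rule:
+#if 0
+static int pick_vector_n(int n)
+{
+    for(int vn = 8; vn > 1; vn /= 2) // try 8, then 4, then 2
+        if(n % vn == 0)
+            return vn;
+    return 1; // scalar fallback
+}
+#endif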
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp new file mode 100644 index 0000000000..ddb8e1b354 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp new file mode 100644 index 0000000000..2a87614403 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp new file mode 100644 index 0000000000..045a3b8880 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
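+
+// NOTE: the one-letter legend "rm rn tm tn vn pd x 3p" above each
+// instantiation list appears to name the trait_ parameters: rm/rn and tm/tn
+// are most likely the block-tile repeat and thread counts along M and N, and
+// vn the vector width along N, while pd, x and 3p line up with the kPadN,
+// kSaveX and kThreePass flags of add_rmsnorm2d_rdquant_fwd_traits_ from the
+// shared instance header.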
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp new file mode 100644 index 0000000000..1028973e74 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000..b8439a0ce9 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp new file mode 100644 index 0000000000..b24b245757 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp new file mode 100644 index 0000000000..14f0ec8525 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp @@ -0,0 +1,42 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp new file mode 100644 index 0000000000..3e3a6d75b9 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp new file mode 100644 index 0000000000..04d735c12c --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp @@ -0,0 +1,26 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +#if 0 +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +#endif + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp new file mode 100644 index 0000000000..5893d6c3ee --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp new file mode 100644 index 0000000000..ec9c417bf3 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp new file mode 100644 index 0000000000..5bc8245106 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp new file mode 100644 index 0000000000..c022c62de6 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp new file mode 100644 index 0000000000..19172b0793 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp new file mode 100644 index 0000000000..f491d92787 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000..065f0ea4cc --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp new file mode 100644 index 0000000000..be8c6c4de5 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp new file mode 100644 index 0000000000..ad2dfd931e --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp @@ -0,0 +1,41 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp new file mode 100644 index 0000000000..e3afa07fa4 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
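+
+// NOTE: this translation unit, like its bf16 counterpart, instantiates
+// kernels through the shared header included below, which selects between
+// the one-pass and three-pass normalization pipelines based on the trait's
+// kThreePass flag, presumably via the usual std::conditional_t pattern
+// (illustrative sketch only):
+#if 0
+using Pipeline = std::conditional_t<Traits_::kThreePass, ThreePassPipeline, OnePassPipeline>;
+#endif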
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp new file mode 100644 index 0000000000..25b10e1dc4 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -0,0 +1,70 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "add_rmsnorm2d_rdquant_fwd.hpp" +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = add_rmsnorm2d_rdquant_fwd_args; + +template +using trait_ = add_rmsnorm2d_rdquant_fwd_traits_; + +template +float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) +{ + using InputDataType = typename Traits_::InputDataType; + using QuantizedDataType = typename Traits_::QuantizedDataType; + + using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem< + typename AddRmsnormRdquantTypeConfig::ADataType, + typename AddRmsnormRdquantTypeConfig::BDataType, + typename AddRmsnormRdquantTypeConfig::GammaDataType, + typename AddRmsnormRdquantTypeConfig::ComputeDataType, + typename AddRmsnormRdquantTypeConfig::XDataType, + typename AddRmsnormRdquantTypeConfig::YScaleDataType, + typename AddRmsnormRdquantTypeConfig::QYDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kSaveX, + Traits_::kThreePass>; + + using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass; + using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 79bd51d65c..f654d1a917 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -242,21 +242,20 @@ class TestCkTileBatchedGemm : public ::testing::Test c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - ck_tile::BatchedGemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = 1; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = StrideA; - args.stride_B = StrideB; - args.stride_E = StrideC; - args.batch_stride_A = BatchStrideA; - args.batch_stride_B = BatchStrideB; - args.batch_stride_E = BatchStrideC; - args.batch_count = BatchCount; + ck_tile::BatchedGemmHostArgs 
args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + 1, + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchCount}; invoke_batched_gemm(args, ck_tile::stream_config{nullptr, false}); diff --git a/test/ck_tile/batched_transpose/CMakeLists.txt b/test/ck_tile/batched_transpose/CMakeLists.txt new file mode 100644 index 0000000000..ac8e3dac49 --- /dev/null +++ b/test/ck_tile/batched_transpose/CMakeLists.txt @@ -0,0 +1,33 @@ +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + + function (add_batched_transpose_test TARGET_NAME MAIN_SRC) + message(DEBUG "adding ${TARGET_NAME}") + + add_test_executable(${TARGET_NAME} ${MAIN_SRC} batched_transpose_api.cpp) + target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + # list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + target_compile_options(${TARGET_NAME} PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) + + endfunction(add_batched_transpose_test TARGET_NAME MAIN_SRC) + + set(CUSTOM_TARGET_NAME test_ck_tile_batched_transpose) + + add_custom_target(${CUSTOM_TARGET_NAME}) + + add_batched_transpose_test(test_ck_tile_batched_transpose_fp16 batched_transpose_fp16.cpp) + add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_batched_transpose_fp16) + + add_batched_transpose_test(test_ck_tile_batched_transpose_fp8 batched_transpose_fp8.cpp) + add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_batched_transpose_fp8) + + add_batched_transpose_test(test_ck_tile_batched_transpose_bf16 batched_transpose_bf16.cpp) + add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_batched_transpose_bf16) + + +else() + message(DEBUG "Skipping ck_tile batched_transpose tests for current target") +endif() diff --git a/example/ck_tile/37_transpose/transpose_example.hpp b/test/ck_tile/batched_transpose/batched_transpose.hpp similarity index 68% rename from example/ck_tile/37_transpose/transpose_example.hpp rename to test/ck_tile/batched_transpose/batched_transpose.hpp index 8128d583ef..bd1abb1191 100644 --- a/example/ck_tile/37_transpose/transpose_example.hpp +++ b/test/ck_tile/batched_transpose/batched_transpose.hpp @@ -1,11 +1,9 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" #include "ck_tile/ops/reduce.hpp" -#include "batched_transpose_kernel.hpp" -#include "block_transpose.hpp" -#include "transpose_policy.hpp" +#include "ck_tile/ops/batched_transpose.hpp" #include #include diff --git a/example/ck_tile/37_transpose/transpose_example.cpp b/test/ck_tile/batched_transpose/batched_transpose.inc similarity index 59% rename from example/ck_tile/37_transpose/transpose_example.cpp rename to test/ck_tile/batched_transpose/batched_transpose.inc index ac27ca7911..30084f5664 100644 --- a/example/ck_tile/37_transpose/transpose_example.cpp +++ b/test/ck_tile/batched_transpose/batched_transpose.inc @@ -1,5 +1,5 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #include #include @@ -10,49 +10,7 @@ #include #include -#include "transpose_example.hpp" - -#if 0 -template -void dump_host_tensor_4d(const ck_tile::HostTensor& x) -{ - auto len = x.get_lengths(); - assert(len.size() == 4); - std::cout << "["; - for(size_t i = 0; i < len[0]; i++) - { - std::cout << i << ": ["; - for(size_t j = 0; j < len[1]; j++) - { - std::cout << j << ": ["; - for(size_t k = 0; k < len[2]; k++) - { - std::cout << k << ": ["; - for(size_t v = 0; v < len[3]; v++) - { - if constexpr(std::is_same_v) - { - auto m = - ck_tile::type_convert(x(std::vector{i, j, k, v})); - - std::cout << m; - if(v != len[3] - 1) - std::cout << ","; - } - else - { - std::cout << x(std::vector{i, j, k, v}) << " "; - } - } - std::cout << "]" << std::endl; - } - std::cout << "]" << std::endl; - } - std::cout << std::endl; - } - std::cout << "--------------------" << std::endl; -} -#endif +#include "batched_transpose.hpp" // different threshold for different dtype template @@ -88,21 +46,23 @@ auto get_elimit(std::string init_method) } } -auto create_args(int argc, char* argv[]) +auto create_args(int argc, char* argv[], int index = 0) { ck_tile::ArgParser arg_parser; arg_parser.insert("v", "1", "whether do CPU validation or not") .insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)") - .insert("N", "2", "input batch size. ") + .insert("N", "1", "input batch size. ") .insert("C", "64", "input channel size.") - .insert("H", "1", "input height size.") + .insert("H", "18", "input height size.") .insert("W", "64", "input width size. ") .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default") .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("seed", "-1", "seed to be used, -1 means random every time") .insert("kname", "0", "t to 1 will print kernel name"); - bool result = arg_parser.parse(argc, argv); + bool result = arg_parser.parse(argc, argv, index); return std::make_tuple(result, arg_parser); } @@ -115,6 +75,8 @@ bool run_batched_transpose(ck_tile::ArgParser args) int C = args.get_int("C"); int H = args.get_int("H"); int W = args.get_int("W"); + int n_warmup = args.get_int("warmup"); + int n_repeat = args.get_int("repeat"); std::string layout_in = args.get_str("layout_in"); std::string layout_out = args.get_str("layout_out"); int seed = args.get_int("seed"); @@ -177,7 +139,7 @@ bool run_batched_transpose(ck_tile::ArgParser args) return a_; }(); - ck_tile::stream_config sc{nullptr, true}; + ck_tile::stream_config sc{nullptr, true, n_warmup, n_repeat}; auto ms = batched_transpose(trait, karg, sc); @@ -202,7 +164,8 @@ bool run_batched_transpose(ck_tile::ArgParser args) layout_in.c_str(), ms); if(ms < 0) - printf("not supported\n"); + printf("------------------------------------not " + "supported-------------------------------------\n"); fflush(stdout); if(ms < 0) @@ -227,31 +190,94 @@ bool run_batched_transpose(ck_tile::ArgParser args) rtn &= ck_tile::check_err( y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol); } - printf("valid:%s\n", rtn ? "y" : "n"); + printf("-----------------------------------------------------------------------valid:%s--------" + "--------------------------------------------------------------------\n", + rtn ? 
"y" : "n"); fflush(stdout); return rtn; } -int main(int argc, char** argv) +template +bool run_test_case(int argc, char** argv) { auto [result, args] = create_args(argc, argv); if(!result) - return -1; - std::string prec = args.get_str("pr"); + return false; - bool r = true; - if(prec.compare("fp16") == 0) - { - r &= run_batched_transpose(args); - } - else if(prec.compare("fp8") == 0) - { - r &= run_batched_transpose(args); - } - else - { - std::cerr << "Unsupported data type: " << prec << std::endl; - } - - return r ? 0 : -1; + return run_batched_transpose(args); +} + +template +bool run_test_cases(std::vector>& test_cases) +{ + bool valid = true; + for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx) + { + constexpr int num_args = 7; + char* argv[num_args]; + + assert(test_cases[test_idx].size() == num_args && + "invalid number of arguments in test case"); + + for(std::size_t idx = 0; idx < test_cases[test_idx].size(); ++idx) + { + argv[idx] = test_cases[test_idx][idx].data(); + } + + valid = valid && run_test_case(num_args, argv); + + if(!valid) + break; + } + + return valid; +} + +std::vector> generate_test_cases(const std::string prec) +{ + return { + {"-pr=" + prec, "-N=1", "-C=32", "-H=1", "-W=32", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=1", "-C=64", "-H=1", "-W=64", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=2", "-C=12", "-H=1", "-W=32", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=3", "-C=1334", "-H=1", "-W=37", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=4", "-C=27", "-H=1", "-W=32", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=5", "-C=1234", "-H=1", "-W=12", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=1", "-C=1", "-H=1", "-W=1", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=1", "-C=1", "-H=1", "-W=1", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, + "-N=128", + "-C=1024", + "-H=64", + "-W=64", + "-layout_in=NCHW", + "-layout_out=NHWC"}, + {"-pr=" + prec, + "-N=128", + "-C=1024", + "-H=64", + "-W=64", + "-layout_in=NHWC", + "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=16", "-C=64", "-H=32", "-W=128", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=16", "-C=64", "-H=128", "-W=32", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=1", "-C=2048", "-H=1", "-W=1", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=1", "-C=2048", "-H=1", "-W=1", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, + "-N=1", + "-C=1", + "-H=1024", + "-W=1024", + "-layout_in=NCHW", + "-layout_out=NHWC"}, + {"-pr=" + prec, + "-N=1", + "-C=1", + "-H=1024", + "-W=1024", + "-layout_in=NHWC", + "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=8", "-C=16", "-H=8", "-W=16", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=8", "-C=16", "-H=8", "-W=16", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=1", "-C=64", "-H=1", "-W=1024", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=1", "-C=64", "-H=1024", "-W=1", "-layout_in=NHWC", "-layout_out=NCHW"}}; } diff --git a/test/ck_tile/batched_transpose/batched_transpose_api.cpp b/test/ck_tile/batched_transpose/batched_transpose_api.cpp new file mode 100644 index 0000000000..973a1967f2 --- /dev/null +++ b/test/ck_tile/batched_transpose/batched_transpose_api.cpp @@ -0,0 +1,109 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+#include "batched_transpose.hpp"
+
+template
+float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
+{
+    uint32_t dim_stride = a.height * a.width;
+
+    a.dim_stride  = dim_stride;
+    a.dim_block_h = block_y;
+    a.dim_block_w = block_x;
+
+    using block_tile  = ck_tile::sequence;
+    using warp_layout = ck_tile::sequence;
+
+    using ts_problem =
+        ck_tile::BatchedTransposeProblem;
+    using ts_pipeline = ck_tile::BatchedTransposePipeline;
+
+    using kernel = ck_tile::BatchedTransposeKernel;
+
+    auto kargs = kernel::MakeKargs(a);
+
+    const dim3 grids      = kernel::GridSize(a);
+    constexpr dim3 blocks = kernel::BlockSize();
+
+    printf("Grid: %u %u %u\n", grids.x, grids.y, grids.z);
+    printf("Block: %u %u %u\n", blocks.x, blocks.y, blocks.z);
+    printf("kargs: kargs.batch %d kargs.height %d kargs.width %d kargs.dim_stride %d\n",
+           kargs.batch,
+           kargs.height,
+           kargs.width,
+           kargs.dim_stride);
+
+    printf("Launching Kernel...\n");
+
+    float ave_time = ck_tile::launch_kernel(
+        s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs));
+
+    printf("Kernel finished...\n");
+
+    return ave_time;
+}
+
+// Param comb: short name, element type, block_x & y, warp_x & y, pad_m & pad_n
+#define FOREACH_TRANSPOSE_PARAM(F)                       \
+    F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, true, true)     \
+    F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, false, false)   \
+    F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, true, true)   \
+    F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, false, false) \
+    F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, true, true)   \
+    F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, false, false)
+
+// Macro that defines one static function per line
+#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, PADM, PADN)               \
+    static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##PADM##_##PADN( \
+        batched_transpose_kargs& a, ck_tile::stream_config& s)                            \
+    {                                                                                     \
+        return batched_transpose_dispatch(a, s);                                          \
+    }
+
+FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN)
+
+float batched_transpose(batched_transpose_trait t,
+                        batched_transpose_kargs a,
+                        ck_tile::stream_config s)
+{
+    if(t.type == "fp8")
+    {
+        if(a.height % 64 == 0 && a.width % 64 == 0)
+        {
+            return transpose_fn_fp8_64_64_1_1_false_false(a, s);
+        }
+        else
+        {
+            return transpose_fn_fp8_64_64_1_1_true_true(a, s);
+        }
+    }
+    else if(t.type == "fp16")
+    {
+        if(a.height % 64 == 0 && a.width % 64 == 0)
+        {
+            return transpose_fn_fp16_64_64_1_1_false_false(a, s);
+        }
+        else
+        {
+            return transpose_fn_fp16_64_64_1_1_true_true(a, s);
+        }
+    }
+    else if(t.type == "bf16")
+    {
+        if(a.height % 64 == 0 && a.width % 64 == 0)
+        {
+            return transpose_fn_bf16_64_64_1_1_false_false(a, s);
+        }
+        else
+        {
+            return transpose_fn_bf16_64_64_1_1_true_true(a, s);
+        }
+    }
+    return -1; // unsupported data type
+}
diff --git a/test/ck_tile/batched_transpose/batched_transpose_bf16.cpp b/test/ck_tile/batched_transpose/batched_transpose_bf16.cpp
new file mode 100644
index 0000000000..42642335f6
--- /dev/null
+++ b/test/ck_tile/batched_transpose/batched_transpose_bf16.cpp
@@ -0,0 +1,10 @@
+// Copyright © Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT +#include "batched_transpose.inc" + +int main() +{ + std::vector> test_cases = generate_test_cases("bf16"); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/batched_transpose/batched_transpose_fp16.cpp b/test/ck_tile/batched_transpose/batched_transpose_fp16.cpp new file mode 100644 index 0000000000..5562dd54e8 --- /dev/null +++ b/test/ck_tile/batched_transpose/batched_transpose_fp16.cpp @@ -0,0 +1,10 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "batched_transpose.inc" + +int main() +{ + std::vector> test_cases = generate_test_cases("fp16"); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/batched_transpose/batched_transpose_fp8.cpp b/test/ck_tile/batched_transpose/batched_transpose_fp8.cpp new file mode 100644 index 0000000000..45e79fb4c2 --- /dev/null +++ b/test/ck_tile/batched_transpose/batched_transpose_fp8.cpp @@ -0,0 +1,10 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "batched_transpose.inc" + +int main() +{ + std::vector> test_cases = generate_test_cases("fp8"); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/container/CMakeLists.txt b/test/ck_tile/container/CMakeLists.txt new file mode 100644 index 0000000000..50670c83e4 --- /dev/null +++ b/test/ck_tile/container/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_ck_tile_tuple_apply test_tuple_apply.cpp) + if(result EQUAL 0) + target_link_libraries(test_ck_tile_tuple_apply PRIVATE utility) + endif() +endif() \ No newline at end of file diff --git a/test/ck_tile/container/test_tuple_apply.cpp b/test/ck_tile/container/test_tuple_apply.cpp new file mode 100644 index 0000000000..91e0c22895 --- /dev/null +++ b/test/ck_tile/container/test_tuple_apply.cpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "ck_tile/core.hpp" + +using namespace ck_tile; + +class TestCkTileTupleApply : public ::testing::Test +{ + public: + // Test functors for different scenarios + struct AddFunction + { + template + CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const + { + return (args + ...); + } + }; + + struct MultiplyFunction + { + template + CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const + { + return (args * ...); + } + }; + + struct MaxFunction + { + template + CK_TILE_HOST_DEVICE constexpr T operator()(T a) const + { + return a; + } + + template + CK_TILE_HOST_DEVICE constexpr T operator()(T a, Args... args) const + { + auto rest_max = operator()(args...); + return a > rest_max ? a : rest_max; + } + }; + + struct ReturnTupleFunction + { + template + CK_TILE_HOST_DEVICE constexpr auto operator()(Args... 
args) const + { + return make_tuple(args..., sizeof...(args)); + } + }; +}; + +TEST_F(TestCkTileTupleApply, BasicArithmetic) +{ + // Test with simple arithmetic operations + auto t1 = make_tuple(1, 2, 3); + auto result1 = apply(AddFunction{}, t1); + EXPECT_EQ(result1, 6); + + auto t2 = make_tuple(2, 3, 4, 5); + auto result2 = apply(MultiplyFunction{}, t2); + EXPECT_EQ(result2, 120); +} + +TEST_F(TestCkTileTupleApply, SingleElement) +{ + // Test with single element tuple + auto t1 = make_tuple(42); + auto result1 = apply(AddFunction{}, t1); + EXPECT_EQ(result1, 42); + + auto result2 = apply(MultiplyFunction{}, t1); + EXPECT_EQ(result2, 42); +} + +TEST_F(TestCkTileTupleApply, EmptyTuple) +{ + // Test with empty tuple + auto t = tuple<>{}; + auto result = apply([]() { return 100; }, t); + EXPECT_EQ(result, 100); +} + +TEST_F(TestCkTileTupleApply, DifferentTypes) +{ + // Test with different data types + auto t1 = make_tuple(1, 2.5f, 3.0); + auto result1 = apply(AddFunction{}, t1); + EXPECT_FLOAT_EQ(result1, 6.5f); + + // Test with mixed integer and floating point + auto t2 = make_tuple(10, 0.5f); + auto result2 = apply(MultiplyFunction{}, t2); + EXPECT_FLOAT_EQ(result2, 5.0f); +} + +TEST_F(TestCkTileTupleApply, ReturnTuple) +{ + // Test function that returns a tuple + auto t = make_tuple(1, 2, 3); + auto result = apply(ReturnTupleFunction{}, t); + + EXPECT_EQ(result.get<0>(), 1); + EXPECT_EQ(result.get<1>(), 2); + EXPECT_EQ(result.get<2>(), 3); + EXPECT_EQ(result.get<3>(), 3); // size +} + +TEST_F(TestCkTileTupleApply, LambdaFunction) +{ + // Test with lambda functions + auto t1 = make_tuple(5, 10, 15); + auto result1 = apply([](auto a, auto b, auto c) { return a + b + c; }, t1); + EXPECT_EQ(result1, 30); + + // Test lambda with capture + int multiplier = 2; + auto result2 = + apply([multiplier](auto a, auto b) { return (a + b) * multiplier; }, make_tuple(3, 7)); + EXPECT_EQ(result2, 20); +} + +TEST_F(TestCkTileTupleApply, ConstexprContext) +{ + // Test in constexpr context + constexpr auto t = make_tuple(2, 3, 4); + constexpr auto result = apply(MultiplyFunction{}, t); + static_assert(result == 24, "Constexpr apply should work"); + EXPECT_EQ(result, 24); +} + +TEST_F(TestCkTileTupleApply, ReferenceTypes) +{ + // Test with reference types using tie + int a = 1, b = 2, c = 3; + auto ref_tuple = tie(a, b, c); + + // Function that modifies references + apply( + [](auto& x, auto& y, auto& z) { + x += 10; + y += 20; + z += 30; + }, + ref_tuple); + + EXPECT_EQ(a, 11); + EXPECT_EQ(b, 22); + EXPECT_EQ(c, 33); +} + +TEST_F(TestCkTileTupleApply, MoveSemantics) +{ + // Test with move semantics + auto t = make_tuple(1, 2, 3); + auto result = apply(AddFunction{}, std::move(t)); + EXPECT_EQ(result, 6); +} + +TEST_F(TestCkTileTupleApply, NumberTypes) +{ + // Test with ck_tile::number types + auto t = make_tuple(number<1>{}, number<2>{}, number<3>{}); + auto result = apply([](auto a, auto b, auto c) { return a + b + c; }, t); + EXPECT_EQ(result, 6); +} + +TEST_F(TestCkTileTupleApply, ElementwiseOperation) +{ + // Test simulating elementwise operations + auto input1 = make_tuple(1.0f, 2.0f, 3.0f); + auto input2 = make_tuple(4.0f, 5.0f, 6.0f); + + auto add_elementwise = [](const auto& a, const auto& b) { + return apply( + [&b](auto... args_a) { + return apply( + [args_a...](auto... 
args_b) { return make_tuple((args_a + args_b)...); }, b); + }, + a); + }; + + auto result = add_elementwise(input1, input2); + + EXPECT_FLOAT_EQ(result.get<0>(), 5.0f); + EXPECT_FLOAT_EQ(result.get<1>(), 7.0f); + EXPECT_FLOAT_EQ(result.get<2>(), 9.0f); +} + +template +class TestCkTileTupleApplySize : public TestCkTileTupleApply +{ + protected: + static constexpr int Size = T::value; +}; + +using TupleSizes = ::testing::Types, + std::integral_constant, + std::integral_constant, + std::integral_constant, + std::integral_constant, + std::integral_constant>; + +TYPED_TEST_SUITE(TestCkTileTupleApplySize, TupleSizes); + +TYPED_TEST(TestCkTileTupleApplySize, GeneratedTupleSum) +{ + constexpr int N = TypeParam::value; + + // Generate tuple with values 1, 2, 3, ..., N + constexpr auto t = generate_tuple([](auto i) { return i.value + 1; }, number{}); + + // Sum all elements + constexpr auto result = apply(TestCkTileTupleApply::AddFunction{}, t); + + // Expected sum: 1 + 2 + ... + N = N*(N+1)/2 + constexpr int expected = N * (N + 1) / 2; + static_assert(result == expected); +} diff --git a/test/ck_tile/data_type/test_pk_int4.cpp b/test/ck_tile/data_type/test_pk_int4.cpp index 4e9fb20efc..1ccae88112 100644 --- a/test/ck_tile/data_type/test_pk_int4.cpp +++ b/test/ck_tile/data_type/test_pk_int4.cpp @@ -36,8 +36,8 @@ TEST(PackedInt4, ConvertToHalf) const half_t first_input_val = ck_tile::type_convert(7.f); const half_t second_input_val = ck_tile::type_convert(-1.f); #else - const half_t first_input_val = ck_tile::type_convert(-1.f); - const half_t second_input_val = ck_tile::type_convert(7.f); + const half_t first_input_val = ck_tile::type_convert(-1.f); + const half_t second_input_val = ck_tile::type_convert(7.f); #endif uint8_t data = 0b11110111; // {-1, 7} pk_int4_t in = ck_tile::bit_cast(data); @@ -53,8 +53,8 @@ TEST(PackedInt4, ConvertToBHalf) const bf16_t first_input_val = ck_tile::type_convert(7.f); const bf16_t second_input_val = ck_tile::type_convert(-1.f); #else - const bf16_t first_input_val = ck_tile::type_convert(-1.f); - const bf16_t second_input_val = ck_tile::type_convert(7.f); + const bf16_t first_input_val = ck_tile::type_convert(-1.f); + const bf16_t second_input_val = ck_tile::type_convert(7.f); #endif uint8_t data = 0b11110111; // {-1, 7} pk_int4_t in = ck_tile::bit_cast(data); diff --git a/test/ck_tile/elementwise/CMakeLists.txt b/test/ck_tile/elementwise/CMakeLists.txt new file mode 100644 index 0000000000..d22a30ff56 --- /dev/null +++ b/test/ck_tile/elementwise/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp) + if(result EQUAL 0) + target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility) + endif() +endif() \ No newline at end of file diff --git a/test/ck_tile/elementwise/test_elementwise_1d.cpp b/test/ck_tile/elementwise/test_elementwise_1d.cpp new file mode 100644 index 0000000000..7013792335 --- /dev/null +++ b/test/ck_tile/elementwise/test_elementwise_1d.cpp @@ -0,0 +1,210 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
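+
+// NOTE: RunTest below launches one workgroup per kBlockM-sized tile of the M
+// dimension, so the grid size is the ceiling division ceil(total_m / kBlockM);
+// with the block tile used in this file (kBlockM = 256), 516 elements map to
+// 3 workgroups. A hypothetical helper (illustration only) showing the same
+// computation:
+#if 0
+constexpr ck_tile::index_t ceil_div(ck_tile::index_t a, ck_tile::index_t b)
+{
+    return (a + b - 1) / b;
+}
+static_assert(ceil_div(516, 256) == 3); // cf. RunElementwise_516 below
+#endif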
+ +#include +#include +#include // For std::abs +#include +#include // For std::is_same_v, std::is_floating_point_v +#include // For std::index_sequence, std::forward + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" +#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +// Traits to get number of inputs for an elementwise operation +template +struct elementwise_op_traits; + +template <> +struct elementwise_op_traits +{ + static constexpr int num_inputs = 2; +}; +template <> +struct elementwise_op_traits +{ + static constexpr int num_inputs = 1; +}; + +template +auto make_uniform_array_with_factory(F&& factory) +{ + return [&](std::index_sequence) { + return std::array, D>{factory(Is)...}; + }(std::make_index_sequence{}); +} + +template +class TestCkTileElementwise : public ::testing::Test +{ + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using YDataType = std::tuple_element_t<1, Tuple>; + using ComputeDataType = std::tuple_element_t<2, Tuple>; + using ElementwiseOpType = std::tuple_element_t<3, Tuple>; + using BlockWarps_ = std::tuple_element_t<4, Tuple>; + using BlockTile_ = std::tuple_element_t<5, Tuple>; + using WarpTile_ = std::tuple_element_t<6, Tuple>; + using TestElementWiseShape = + ck_tile::ElementWiseShape; + static constexpr int NumInputs = elementwise_op_traits::num_inputs; + + void RunTest(ck_tile::index_t total_m_elements) + { + // Dims and Strides (1D example) + auto lens = ck_tile::make_tuple(total_m_elements); + auto strides = ck_tile::make_tuple( + static_cast(1)); // Strides for the single dimension + + // Host Tensors + auto h_xs = make_uniform_array_with_factory([&](std::size_t) { + auto ret = ck_tile::HostTensor({total_m_elements}); + ck_tile::FillUniformDistribution{0.f, 5.f}(ret); + return ret; + }); + ck_tile::HostTensor h_y({total_m_elements}); + h_y.SetZero(); + ck_tile::HostTensor h_y_ref({total_m_elements}); + h_y_ref.SetZero(); + + // Device Buffers + auto d_xs_mems_owner = make_uniform_array_with_factory( + [&](std::size_t i) { return ck_tile::DeviceMem(h_xs[i]); }); + for(int i = 0; i < NumInputs; ++i) + { + d_xs_mems_owner[i].ToDevice(h_xs[i].data()); + } + + ck_tile::DeviceMem d_y_mem(h_y); + d_y_mem.SetZero(); + + auto d_x_ptrs_tuple = [&](std::index_sequence) { + return ck_tile::make_tuple( + static_cast(d_xs_mems_owner[Is].GetDeviceBuffer())...); + }(std::make_index_sequence{}); + + YDataType* p_y_device = static_cast(d_y_mem.GetDeviceBuffer()); + + // Problem and Policy + using Problem = ck_tile::ElementWisePipelineProblem; + using Policy = ck_tile::ElementWiseDefaultPolicy; + + ck_tile::ElementWiseKernel ew_kernel; + + // Launch configuration + ck_tile::index_t grid_size = + (total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM; + dim3 grid(grid_size, 1, 1); + dim3 block(TestElementWiseShape::kBlockSize, 1, 1); + constexpr ck_tile::index_t kBlockPerCu = 1; + + ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log + + // Check if the kernel configuration is supported + if(!ew_kernel.IsSupportedArgument(lens)) + { + throw std::runtime_error( + 
"The kernel configuration is not supported for the given input size."); + } + + ck_tile::launch_kernel( + s, + ck_tile::make_kernel // MinBlockPerCu + (ew_kernel, + grid, + block, + 0, // actual shared memory + lens, + strides, // input strides + strides, // output strides + d_x_ptrs_tuple, + p_y_device)); + + d_y_mem.FromDevice(h_y.data()); + + // Reference computation on host + ElementwiseOpType op_host; + for(ck_tile::index_t i = 0; i < total_m_elements; ++i) + { + auto get_host_op_args = [&](std::index_sequence) { + return ck_tile::make_tuple(static_cast(h_xs[Is](i))...); + }(std::make_index_sequence{}); + + YDataType temp_y_val; + ck_tile::apply( + [&](auto&&... host_input_args) { + op_host(temp_y_val, + std::forward(host_input_args)...); + }, + get_host_op_args); + h_y_ref(i) = temp_y_val; + } + + // Check results + check_err(h_y, h_y_ref, "Error: Incorrect results!", 1e-5, 1e-5); + } +}; + +// Shape parameters (can be shared or varied per test type) +using Shape1_BlockWarps = ck_tile::sequence<1>; // 1D warp arrangement in M +using Shape1_BlockTile = ck_tile::sequence<256>; // M-dimension of block tile +using Shape1_WarpTile = ck_tile::sequence<64>; // M-dimension of warp tile + +// Test configurations +using TestConfig_F32_Add = std::tuple; + +using TestConfig_F32_Relu = std::tuple; + +using TestConfig_F16_Add = std::tuple; + +using TestTypes = ::testing::Types; + +TYPED_TEST_SUITE(TestCkTileElementwise, TestTypes); + +TYPED_TEST(TestCkTileElementwise, RunElementwise_1024) { this->RunTest(1024); } + +TYPED_TEST(TestCkTileElementwise, RunElementwise_513) +{ + EXPECT_THROW((this->RunTest(513)), + std::runtime_error); // Test with an input size that's not a multiple of kVectorM +} + +TYPED_TEST(TestCkTileElementwise, RunElementwise_516) +{ + this->RunTest(516); // Test with an input size that's not a multiple of blockM +} + +TYPED_TEST(TestCkTileElementwise, RunElementwise_Small_32) +{ + this->RunTest(32); // Test with a very small size +} diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 8f880b8fde..6cbdc1a24e 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -20,6 +20,16 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) + + + add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_basic_fp8 test_gemm_pipeline_basic_fp8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") endif() @@ -27,4 +37,13 @@ endif() if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS 
diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt
index 8f880b8fde..6cbdc1a24e 100644
--- a/test/ck_tile/gemm/CMakeLists.txt
+++ b/test/ck_tile/gemm/CMakeLists.txt
@@ -20,6 +20,16 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
     target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
     target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
     target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
+
+
+    add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp)
+    target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+    add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp)
+    target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+    add_test_executable(test_ck_tile_gemm_pipeline_basic_fp8 test_gemm_pipeline_basic_fp8.cpp)
+    target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+    add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp)
+    target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
 else()
     message(DEBUG "Skipping ck_tile_gemm tests for current target")
 endif()
@@ -27,4 +37,13 @@ endif()

 if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a")
     add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
     target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+
+    add_test_executable(test_ck_tile_gemm_pipeline_universal_fp16 test_gemm_pipeline_universal_fp16.cpp)
+    target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+    add_test_executable(test_ck_tile_gemm_pipeline_universal_bf16 test_gemm_pipeline_universal_bf16.cpp)
+    target_compile_options(test_ck_tile_gemm_pipeline_universal_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+    add_test_executable(test_ck_tile_gemm_pipeline_basic_fp16 test_gemm_pipeline_basic_fp16.cpp)
+    target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+    add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp)
+    target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
 endif()
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp
new file mode 100644
index 0000000000..af2cb398f5
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp
@@ -0,0 +1,5 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+#include "test_gemm_pipeline_basic_run_test.inc"
+
+int main() { return run_gemm_combinations("bf16"); }
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp
new file mode 100644
index 0000000000..fd8c28ef17
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp
@@ -0,0 +1,5 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+#include "test_gemm_pipeline_basic_run_test.inc"
+
+int main() { return run_gemm_combinations("bf8"); }
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp
new file mode 100644
index 0000000000..4a93d6046a
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp
@@ -0,0 +1,5 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+#include "test_gemm_pipeline_basic_run_test.inc"
+
+int main() { return run_gemm_combinations("fp16"); }
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp
new file mode 100644
index 0000000000..fd8c28ef17
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp
@@ -0,0 +1,5 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+#include "test_gemm_pipeline_basic_run_test.inc"
+
+int main() { return run_gemm_combinations("fp8"); }
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc
new file mode 100644
index 0000000000..4321709ea5
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc
@@ -0,0 +1,313 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstring>
+
+#include <cstdlib>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include "ck_tile/host.hpp"
+#include "test_gemm_pipeline_smoke_util.hpp"
+#include "test_gemm_pipeline_smoke_run_test.inc"
+
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename CLayout,
+          bool Persistent>
+float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
+{
+    if constexpr(Persistent)
+        std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
+    // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
+    constexpr bool kPadM = false;
+    constexpr bool kPadN = false;
+    constexpr bool kPadK = false;
+
+    constexpr int kBlockPerCu = 1;
+
+    // This part comes from the Codegen
+    constexpr ck_tile::index_t M_Tile = 256;
+    constexpr ck_tile::index_t N_Tile = 256;
+    constexpr ck_tile::index_t K_Tile = 64;
+
+    constexpr ck_tile::index_t M_Warp = 2;
+    constexpr ck_tile::index_t N_Warp = 2;
+    constexpr ck_tile::index_t K_Warp = 1;
+
+    constexpr ck_tile::index_t M_Warp_Tile = 32;
+    constexpr ck_tile::index_t N_Warp_Tile = 32;
+    constexpr ck_tile::index_t K_Warp_Tile = 16;
+
+    using CodegenGemmShape =
+        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
+                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
+                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
+
+    using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
+
+    using CodegenGemmTraits =
+        ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
+
+    using CodegenPipelineProblem = ck_tile::
+        GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
+
+    using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
+
+    const auto Run = [&](const auto memory_operation_) {
+        constexpr auto memory_operation = memory_operation_.value;
+
+        using GemmEpilogue = ck_tile::CShuffleEpilogue<
+            ck_tile::CShuffleEpilogueProblem<ADataType,
+                                             BDataType,
+                                             ck_tile::tuple<>,
+                                             AccDataType,
+                                             CDataType,
+                                             ck_tile::tuple<>,
+                                             CLayout,
+                                             ck_tile::element_wise::PassThrough,
+                                             CodegenPipelineProblem::kBlockSize,
+                                             TilePartitioner::MPerBlock,
+                                             TilePartitioner::NPerBlock,
+                                             M_Warp,
+                                             N_Warp,
+                                             M_Warp_Tile,
+                                             N_Warp_Tile,
+                                             K_Warp_Tile,
+                                             CodegenPipelineProblem::TransposeC,
+                                             memory_operation>>;
+
+        // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
+        // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
+        using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
+        auto kargs   = Kernel::MakeKernelArgs(args);
+
+        const dim3 grids      = Kernel::GridSize(args.M, args.N, args.k_batch);
+        constexpr dim3 blocks = Kernel::BlockSize();
+
+        if(!Kernel::IsSupportedArgument(kargs))
+        {
+            throw ArgumentsNotSupportedException("Wrong! Arguments not supported! Skipping gemm!\n");
+        }
+
+        if(s.log_level_ > 0)
+        {
+            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
+                      << "shape: " << CodegenGemmShape::GetName() << '\n'
+                      << "problem: " << CodegenPipelineProblem::GetName() << '\n'
+                      << "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
+                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                      << std::endl;
+        }
+
+        float ave_time = ck_tile::launch_kernel(
+            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
+
+        return ave_time;
+    };
+
+    if(args.k_batch == 1)
+    {
+        return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
+                                              ck_tile::memory_operation_enum::set>{});
+    }
+    else
+    {
+        return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
+                                              ck_tile::memory_operation_enum::atomic_add>{});
+    }
+}
+
+template <typename APrecType, typename BPrecType = APrecType>
+bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
+{
+    using Row = ck_tile::tensor_layout::gemm::RowMajor;
+    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+
+    if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
+    {
+        if(a_layout == "R" && b_layout == "C")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Row{}, Col{}, Row{});
+        }
+        else if(a_layout == "C" && b_layout == "C")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Col{}, Col{}, Row{});
+        }
+        else
+        {
+            throw std::runtime_error("Unsupported memory layout for the input matrices when "
+                                     "BPrecType is ck_tile::pk_int4_t!");
+        }
+    }
+    else
+    {
+        if(a_layout == "R" && b_layout == "C")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Row{}, Col{}, Row{});
+        }
+        else if(a_layout == "R" && b_layout == "R")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Row{}, Row{}, Row{});
+        }
+        else if(a_layout == "C" && b_layout == "R")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Col{}, Row{}, Row{});
+        }
+        else if(a_layout == "C" && b_layout == "C")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Col{}, Col{}, Row{});
+        }
+        else
+        {
+            throw std::runtime_error("Unsupported memory layout for the input matrices!");
+        }
+    }
+}
+
+bool run_gemm_test(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return false;
+
+    std::string data_type = arg_parser.get_str("prec");
+    std::string a_layout  = arg_parser.get_str("a_layout");
+    std::string b_layout  = arg_parser.get_str("b_layout");
+
+    if(data_type == "fp16")
+    {
+        return run_gemm_test_prec_type(a_layout, b_layout, argc, argv);
+    }
+    else if(data_type == "bf16")
+    {
+        return run_gemm_test_prec_type(a_layout, b_layout, argc, argv);
+    }
+    else if(data_type == "fp8")
+    {
+        return run_gemm_test_prec_type(
+            a_layout, b_layout, argc, argv);
+    }
+    else if(data_type == "bf8")
+    {
+        return run_gemm_test_prec_type(
+            a_layout, b_layout, argc, argv);
+    }
+    else if(data_type == "pk_int4_t")
+    {
+        // TODO: Add support for bhalf_t ADataType
+        if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
+        {
+            return run_gemm_test_prec_type(
+                a_layout, b_layout, argc, argv);
+        }
+        else
+        {
+            throw std::runtime_error("Unsupported data type for this operation !!!");
+        }
+    }
+    else
+    {
+        throw std::runtime_error("Unsupported data type for this operation !!!");
+    }
+}
+
+int run_gemm_combinations(std::string const& data_type)
+{
+    // Define possible values for each parameter
+    std::vector<std::string> m_values    = {"128", "1024"};
+    std::vector<std::string> n_values    = {"128", "2048"};
+    std::vector<std::string> k_values    = {"64", "128"};
+    std::vector<std::string> prec_values = {data_type};
+
+    // We'll store all our arguments as strings first
+    std::vector<std::string> arg_strings = {"./bin/tile_example_gemm_basic",
+ "", // m placeholder + "", // n placeholder + "", // k placeholder + "-stride_a=0", + "-stride_b=0", + "-stride_c=0", + "", // prec placeholder + "-v=2", + "-warmup=0", + "-repeat=1"}; + + // Create an array of const char pointers for argv + constexpr size_t ARG_COUNT = 11; + constexpr size_t ARG_MAX_LEN = 64; + char args[ARG_COUNT][ARG_MAX_LEN]; + char* argv[ARG_COUNT]; + + // Run all combinations + bool is_success = true; + for(const auto& m : m_values) + { + arg_strings[1] = "-m=" + m; + + for(const auto& n : n_values) + { + arg_strings[2] = "-n=" + n; + + for(const auto& k : k_values) + { + arg_strings[3] = "-k=" + k; + + for(const auto& prec : prec_values) + { + arg_strings[7] = "-prec=" + prec; + + // Set up the argv array with pointers to the string data + for(size_t i = 0; i < ARG_COUNT; i++) + { + strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN); + argv[i] = args[i]; + } + + std::cout << "Arguments received: "; + for(size_t i = 1; i < ARG_COUNT; ++i) + { + std::cout << argv[i] << " "; + } + std::cout << std::endl; + + // Call the function with the current configuration + try + { + is_success = run_gemm_test(ARG_COUNT, argv) && is_success; + } + catch(const ArgumentsNotSupportedException& e) + { + std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n'; + // ArgumentsNotSupportedException is not an error. Do not change is_success + } + catch(const std::runtime_error& e) + { + std::cerr << "Caught runtime error: " << e.what() << '\n'; + is_success = false; + } + } + } + } + } + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc new file mode 100644 index 0000000000..a967b92e7f --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc @@ -0,0 +1,456 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc
new file mode 100644
index 0000000000..a967b92e7f
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc
@@ -0,0 +1,456 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+#pragma once
+
+template <typename Layout>
+static constexpr inline auto is_row_major(Layout layout_)
+{
+    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
+                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
+}
+
+template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
+auto calculate_rtol_atol(const ck_tile::index_t K,
+                         const ck_tile::index_t kbatch,
+                         const float max_accumulated_value)
+{
+    using ComputeType =
+        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+    // Calculate thresholds
+    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+        ck_tile::integer_divide_ceil(K, kbatch));
+    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+    // Calculate error due to split_k accumulation
+    const auto rtol_split_k =
+        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
+    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
+        max_accumulated_value, kbatch);
+    // Use higher threshold
+    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+}
+
+template <typename GemmConfig, typename Tensor>
+void permute_tensor_b(Tensor& tensor)
+{
+    using GemmShape = ck_tile::TileGemmShape<
+        ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
+        ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
+        ck_tile::
+            sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
+        GemmConfig::PermuteA,
+        GemmConfig::PermuteB>;
+
+    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits;
+
+    using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem;
+
+    using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
+        UniversalGemmProblem>;
+
+    const ck_tile::index_t K  = tensor.get_length(0);
+    const ck_tile::index_t N  = tensor.get_length(1);
+    const ck_tile::index_t K1 = GemmPipeline::GetSmemPackB();
+    const ck_tile::index_t K0 = K / K1;
+
+    Tensor tensor_copy = tensor;
+
+    // int K0, N, K1
+    for(int j = 0; j < K0; j++)
+    {
+        for(int i = 0; i < N; i++)
+        {
+            for(int jj = 0; jj < K1; jj++)
+            {
+                tensor(j * N * K1 + i * K1 + jj) = tensor_copy(i * K + (j * K1 + jj));
+            }
+        }
+    }
+}
+
+template <typename Tensor>
+void permute_vectors_i4x4_b(Tensor& tensor)
+{
+    const ck_tile::index_t K = tensor.get_length(0);
+    const ck_tile::index_t N = tensor.get_length(1);
+    // vector pk_i4x4 permute
+    for(int i = 0; i < N; i++)
+    {
+        for(int j = 0; j < K; j += 8)
+        {
+            int8_t input[8];
+
+            for(int k = 0; k < 4; k++)
+            {
+                int8_t i4x2      = tensor(j + k * 2, i).data;
+                input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
+                input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
+            }
+
+            // permute 01234567->20643175
+            {
+                int8_t hi   = input[2];
+                int8_t lo   = input[0];
+                int8_t i4x2 = (hi << 4) | lo;
+
+                tensor(j + 0, i) = i4x2;
+            }
+
+            {
+                int8_t hi   = input[6];
+                int8_t lo   = input[4];
+                int8_t i4x2 = (hi << 4) | lo;
+
+                tensor(j + 2, i) = i4x2;
+            }
+
+            {
+                int8_t hi   = input[3];
+                int8_t lo   = input[1];
+                int8_t i4x2 = (hi << 4) | lo;
+
+                tensor(j + 4, i) = i4x2;
+            }
+
+            {
+                int8_t hi   = input[7];
+                int8_t lo   = input[5];
+                int8_t i4x2 = (hi << 4) | lo;
+
+                tensor(j + 6, i) = i4x2;
+            }
+        }
+    }
+}
+
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename CLayout,
+          bool Persistent>
+float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
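For clarity, a standalone sketch (not part of the patch) of the "01234567 -> 20643175" remap that permute_vectors_i4x4_b performs, shown on plain integers instead of packed nibbles; each output byte stores one (hi, lo) pair.

    #include <array>
    #include <cassert>

    int main()
    {
        const std::array<int, 8> input = {0, 1, 2, 3, 4, 5, 6, 7}; // logical int4 lanes
        // Pairs written by the function above: (2,0), (6,4), (3,1), (7,5)
        const std::array<int, 8> output = {input[2], input[0], input[6], input[4],
                                           input[3], input[1], input[7], input[5]};
        const std::array<int, 8> expected = {2, 0, 6, 4, 3, 1, 7, 5};
        assert(output == expected); // matches the "20643175" comment
        return 0;
    }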
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename CLayout>
+float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
+                  ck_tile::DeviceMem& b_k_n_dev_buf,
+                  ck_tile::DeviceMem& c_m_n_dev_buf,
+                  ck_tile::index_t M,
+                  ck_tile::index_t N,
+                  ck_tile::index_t K,
+                  ck_tile::index_t stride_A,
+                  ck_tile::index_t stride_B,
+                  ck_tile::index_t stride_C,
+                  ck_tile::index_t kbatch,
+                  int n_warmup,
+                  int n_repeat,
+                  bool persistent)
+{
+    ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
+                                  b_k_n_dev_buf.GetDeviceBuffer(),
+                                  c_m_n_dev_buf.GetDeviceBuffer(),
+                                  kbatch,
+                                  M,
+                                  N,
+                                  K,
+                                  stride_A,
+                                  stride_B,
+                                  stride_C};
+
+    float ave_time;
+    if(persistent)
+    {
+        ave_time = gemm<GemmConfig, ADataType, BDataType, DsDataType, AccDataType, CDataType,
+                        ALayout, BLayout, DsLayout, CLayout, true>(
+            args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
+    }
+    else
+    {
+        ave_time = gemm<GemmConfig, ADataType, BDataType, DsDataType, AccDataType, CDataType,
+                        ALayout, BLayout, DsLayout, CLayout, false>(
+            args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
+    }
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_byte =
+        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
+    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
+    float gb_per_sec = num_byte / 1.E6 / ave_time;
+
+    std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
+              << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
+              << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
+              << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
+              << " B_Type=" << DataTypeTraits<BDataType>::name
+              << " C_Type=" << DataTypeTraits<CDataType>::name
+              << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
+              << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
+              << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl;
+
+    return ave_time;
+}
+
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+bool run_gemm_test_with_layouts(int argc,
+                                char* argv[],
+                                const ALayout a_layout                 = ALayout{},
+                                const BLayout b_layout                 = BLayout{},
+                                [[maybe_unused]] const CLayout c_layout = CLayout{})
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return false;
+
+    using AccDataType = typename GemmTypeConfig<ADataType, BDataType>::AccDataType;
+
+    ck_tile::index_t M = arg_parser.get_int("m");
+    ck_tile::index_t N = arg_parser.get_int("n");
+    ck_tile::index_t K = arg_parser.get_int("k");
+
+    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
+    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
+    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
+
+    ck_tile::index_t kbatch      = arg_parser.get_int("split_k");
+    int n_warmup                 = arg_parser.get_int("warmup");
+    int n_repeat                 = arg_parser.get_int("repeat");
+    ck_tile::index_t init_method = arg_parser.get_int("init");
+    bool persistent              = arg_parser.get_int("persistent");
+
+    stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
+    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
+    stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
+
+    ck_tile::HostTensor<ADataType> a_m_k(
+        ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
+    ck_tile::HostTensor<BDataType> b_k_n(
+        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
+    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
+        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
+
+    if(init_method == 0)
+    {
+        ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
+        ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
+    }
+    else if(init_method == 1)
+    {
+        ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
+        ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
+    }
+    else if(init_method == 2)
+    {
+        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
+        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
+    }
+    else
+    {
+        a_m_k.SetZero();
+        b_k_n.SetZero();
+    }
+
+    if(GemmConfig::UseStructuredSparsity)
+    {
+        ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
+    }
+
+    ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
+
+    static_assert(!GemmConfig::PermuteA, "Not implemented");
+    if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
+    {
+        // Permute vector pk_i4x4 data for device implementation
+        ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
+        if constexpr(GemmConfig::PermuteB)
+        {
+            permute_tensor_b<GemmConfig>(b_k_n_dev);
+        }
+        permute_vectors_i4x4_b(b_k_n_dev);
+        b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
+    }
+    else
+    {
+        if constexpr(GemmConfig::PermuteB)
+        {
+            std::cout << "Permute for this DataType is not implemented." << std::endl;
+            return false;
+        }
+        b_k_n_dev_buf.ToDevice(b_k_n.data());
+    }
+
+    a_m_k_dev_buf.ToDevice(a_m_k.data());
+    c_m_n_dev_buf.SetZero();
+    c_m_n_dev_result.SetZero();
+
+    invoke_gemm<GemmConfig,
+                ADataType,
+                BDataType,
+                ck_tile::tuple<>,
+                AccDataType,
+                CDataType,
+                ALayout,
+                BLayout,
+                ck_tile::tuple<>,
+                CLayout>(a_m_k_dev_buf,
+                         b_k_n_dev_buf,
+                         c_m_n_dev_buf,
+                         M,
+                         N,
+                         K,
+                         stride_A,
+                         stride_B,
+                         stride_C,
+                         kbatch,
+                         n_warmup,
+                         n_repeat,
+                         persistent);
+
+    c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
+    bool pass = true;
+
+    if(arg_parser.get_int("v") == 1)
+    {
+        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
+            ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
+        c_m_n_host_ref.SetZero();
+
+        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
+            a_m_k, b_k_n, c_m_n_host_ref);
+        const float max_accumulated_value =
+            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
+        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
+            K, kbatch, max_accumulated_value);
+        pass = ck_tile::check_err(c_m_n_dev_result,
+                                  c_m_n_host_ref,
+                                  "Error: Incorrect results!",
+                                  rtol_atol.at(ck_tile::number<0>{}),
+                                  rtol_atol.at(ck_tile::number<1>{}));
+
+        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                  << std::endl;
+        std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
+    }
+    else if(arg_parser.get_int("v") == 2)
+    {
+        if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
+        {
+            // Restore input for B for gpu reference
+            b_k_n_dev_buf.ToDevice(b_k_n.data());
+        }
+
+        // memory on host to store gpu reference result
+        ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
+            ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
+        // memory on device to store gpu reference result
+        ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
+
+        c_m_n_gpu_ref.SetZero();
+        c_m_n_gpu_buf_ref.SetZero();
+
+        ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
+        BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
+        CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
+
+        ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout,
+                                    CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
+
+        c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
+
+        const float max_accumulated_value =
+            *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
+        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
+            K, kbatch, max_accumulated_value);
+        pass = ck_tile::check_err(c_m_n_dev_result,
+                                  c_m_n_gpu_ref,
+                                  "Error: Incorrect results!",
+                                  rtol_atol.at(ck_tile::number<0>{}),
+                                  rtol_atol.at(ck_tile::number<1>{}));
+        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                  << std::endl;
+        std::cout << "The GPU verification result is: " << (pass ?
"correct" : "fail") << std::endl; + } + + return pass; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp new file mode 100644 index 0000000000..f64d3e092b --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp @@ -0,0 +1,414 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 + +class ArgumentsNotSupportedException : public std::logic_error +{ + public: + explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {} +}; + +// temporary workaround to get k_warp_tile based on PrecType and gfx950 or not +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
+template <typename PrecType>
+struct GemmConfigMemoryIntrawave : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 32;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 4;
+    static constexpr ck_tile::index_t N_Warp = 1;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
+
+    static constexpr bool DoubleSmemBuffer     = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV3 : public GemmConfigBase
+{
+    // Compute V3 only supports the Intrawave scheduler
+    static constexpr ck_tile::index_t M_Tile = 256;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
+
+    static constexpr bool DoubleSmemBuffer     = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV3_1 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 256;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
+
+    static constexpr bool DoubleSmemBuffer     = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV3_2 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
+
+    static constexpr bool DoubleSmemBuffer     = false;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
+
+    static constexpr int kBlockPerCu = 2;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV4 : public GemmConfigBase
+{
+    // Compute V4 only supports the Intrawave scheduler
+    // Uses ping-pong buffering at the LDS level
+    static constexpr ck_tile::index_t M_Tile = 256;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
+
+    static constexpr bool DoubleSmemBuffer     = true;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV4_1 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 256;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
+
+    static constexpr bool DoubleSmemBuffer     = true;
+    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
+};
+
+template <typename PrecType>
+struct GemmConfigComputeV5 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 1;
+    static constexpr ck_tile::index_t K_Warp = 2;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
+
+    static constexpr bool DoubleSmemBuffer         = false;
+    static constexpr ck_tile::index_t Pipeline     = CK_TILE_PIPELINE_COMPUTE_V5;
+    static constexpr ck_tile::index_t NumWaveGroups = 2;
+};
+
+template <typename ADataType, typename BDataType = ADataType>
+struct GemmTypeConfig;
+
+template <>
+struct GemmTypeConfig<ck_tile::half_t>
+{
+    using ADataType   = ck_tile::half_t;
+    using BDataType   = ck_tile::half_t;
+    using AccDataType = float;
+    using CDataType   = ck_tile::half_t;
+    // ToDo: Add more bias config to support different categories of GEMM.
+};
+
+template <>
+struct GemmTypeConfig<ck_tile::bf16_t>
+{
+    using ADataType   = ck_tile::bf16_t;
+    using BDataType   = ck_tile::bf16_t;
+    using AccDataType = float;
+    using CDataType   = ck_tile::bf16_t;
+};
+
+template <>
+struct GemmTypeConfig<ck_tile::fp8_t>
+{
+    using ADataType   = ck_tile::fp8_t;
+    using BDataType   = ck_tile::fp8_t;
+    using AccDataType = float;
+    using CDataType   = ck_tile::half_t;
+};
+
+template <>
+struct GemmTypeConfig<ck_tile::bf8_t>
+{
+    using ADataType   = ck_tile::bf8_t;
+    using BDataType   = ck_tile::bf8_t;
+    using AccDataType = float;
+    using CDataType   = ck_tile::half_t;
+};
+
+template <>
+struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t>
+{
+    using ADataType   = ck_tile::half_t;
+    using BDataType   = ck_tile::pk_int4_t;
+    using AccDataType = float;
+    using CDataType   = ck_tile::half_t;
+};
+
+template <>
+struct GemmTypeConfig<ck_tile::int8_t>
+{
+    using ADataType   = ck_tile::int8_t;
+    using BDataType   = ck_tile::int8_t;
+    using AccDataType = int32_t;
+    using CDataType   = int32_t;
+};
+
+template <typename DataType>
+struct DataTypeTraits;
+
+template <>
+struct DataTypeTraits<float>
+{
+    static constexpr const char* name = "fp32";
+};
+
+template <>
+struct DataTypeTraits<double>
+{
+    static constexpr const char* name = "fp64";
+};
+
+template <>
+struct DataTypeTraits<int32_t>
+{
+    static constexpr const char* name = "int32";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::half_t>
+{
+    static constexpr const char* name = "fp16";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::bf16_t>
+{
+    static constexpr const char* name = "bf16";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::fp8_t>
+{
+    static constexpr const char* name = "fp8";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::bf8_t>
+{
+    static constexpr const char* name = "bf8";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::pk_int4_t>
+{
+    static constexpr const char* name = "pk_int4_t";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::int8_t>
+{
+    static constexpr const char* name = "int8";
+};
+
+template <ck_tile::index_t Pipeline>
+struct PipelineTypeTraits;
+
+template <>
+struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
+{
+    template <typename PipelineProblem>
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
+    template <typename PipelineProblem>
+    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
+};
+
+template <>
+struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
+{
+    template <typename PipelineProblem>
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
+    template <typename PipelineProblem>
+    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
+};
+
+template <>
+struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
+{
+    template <typename PipelineProblem>
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
+    template <typename PipelineProblem>
+    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
+};
+
+template <>
+struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
+{
+    template <typename PipelineProblem>
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
+    template <typename PipelineProblem>
+    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
+};
+
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("m", "3840", "m dimension")
+        .insert("n", "4096", "n dimension")
+        .insert("k", "2048", "k dimension")
+        .insert("a_layout", "R", "A tensor data layout - Row by default")
+        .insert("b_layout", "C", "B tensor data layout - Column by default")
+        .insert("c_layout", "R", "C tensor data layout - Row by default")
+        .insert("stride_a", "0", "Tensor A stride")
+        .insert("stride_b", "0", "Tensor B stride")
+        .insert("stride_c", "0", "Tensor C stride")
+        .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
+        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
+        .insert("warmup", "50", "number of iterations before benchmark the kernel")
+        .insert("repeat", "100", "number of iterations to benchmark the kernel")
+        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("split_k", "1", "splitK value")
+        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
+        .insert("persistent", "0", "0:non-persistent, 1:persistent");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+// host API
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename CLayout,
+          bool Persistent>
+float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
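A minimal usage sketch (not part of the patch) of the create_args plumbing above, using only the ArgParser calls that appear in this file; run_gemm_combinations synthesizes exactly this kind of argv per M/N/K/prec combination.

    // e.g. inside a test's main():
    int main(int argc, char* argv[])
    {
        auto [ok, parser] = create_args(argc, argv);
        if(!ok)
            return EXIT_FAILURE;
        const int m             = parser.get_int("m");      // defaults to 3840
        const std::string prec  = parser.get_str("prec");   // defaults to "fp16"
        // ... dispatch on prec / layouts as run_gemm_test does ...
        return EXIT_SUCCESS;
    }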
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp
new file mode 100644
index 0000000000..0673272f5f
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp
@@ -0,0 +1,16 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+#include <cstring>
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/host.hpp"
+#include "test_gemm_pipeline_smoke_util.hpp"
+#include "test_gemm_pipeline_smoke_run_test.inc"
+#include "test_gemm_pipeline_universal_run_test.inc"
+
+int main() { return run_gemm_combinations("bf16"); }
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp
new file mode 100644
index 0000000000..70eae12e82
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp
@@ -0,0 +1,16 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+#include <cstring>
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/host.hpp"
+#include "test_gemm_pipeline_smoke_util.hpp"
+#include "test_gemm_pipeline_smoke_run_test.inc"
+#include "test_gemm_pipeline_universal_run_test.inc"
+
+int main() { return run_gemm_combinations("bf8"); }
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp
new file mode 100644
index 0000000000..8ea192c7f3
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp
@@ -0,0 +1,16 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+#include <cstring>
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/host.hpp"
+#include "test_gemm_pipeline_smoke_util.hpp"
+#include "test_gemm_pipeline_smoke_run_test.inc"
+#include "test_gemm_pipeline_universal_run_test.inc"
+
+int main() { return run_gemm_combinations("fp16"); }
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp
new file mode 100644
index 0000000000..20414b4fec
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp
@@ -0,0 +1,16 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+#include <cstring>
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/host.hpp"
+#include "test_gemm_pipeline_smoke_util.hpp"
+#include "test_gemm_pipeline_smoke_run_test.inc"
+#include "test_gemm_pipeline_universal_run_test.inc"
+
+int main() { return run_gemm_combinations("fp8"); }
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc
new file mode 100644
index 0000000000..860541ef18
--- /dev/null
+++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc
@@ -0,0 +1,394 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+#pragma once
+
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename CLayout,
+          bool Persistent>
+float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
+{
+    using GemmShape = ck_tile::TileGemmShape<
+        ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
+        ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
+        ck_tile::
+            sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
+        GemmConfig::PermuteA,
+        GemmConfig::PermuteB>;
+
+    using TilePartitioner =
+        ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
+                                                   GemmConfig::TileParitionerGroupNum,
+                                                   GemmConfig::TileParitionerM01>;
+
+    using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
+                                           GemmConfig::kPadN,
+                                           GemmConfig::kPadK,
+                                           ALayout,
+                                           BLayout,
+                                           CLayout>;
+
+    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits;
+    using GemmPipelineProblem =
+        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
+
+    using BaseGemmPipeline = typename PipelineTypeTraits<
+        GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
+
+    const ck_tile::index_t k_grain     = args.k_batch * GemmConfig::K_Tile;
+    const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
+    const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
+    const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
+    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
+
+    float ave_time{0};
+
+    const auto Run = [&](const auto has_hot_loop_,
+                         const auto tail_number_,
+                         const auto memory_operation_) {
+        constexpr bool has_hot_loop_v   = has_hot_loop_.value;
+        constexpr auto tail_number_v    = tail_number_.value;
+        constexpr auto scheduler        = GemmConfig::Scheduler;
+        constexpr auto memory_operation = memory_operation_.value;
+
+        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem;
+
+        using GemmPipeline = typename PipelineTypeTraits<
+            GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
+        using GemmEpilogue = ck_tile::CShuffleEpilogue<
+            ck_tile::CShuffleEpilogueProblem>;
+        using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
+        auto kargs   = Kernel::MakeKernelArgs(args);
+
+        dim3 grids;
+        if constexpr(Persistent)
+        {
+            grids = Kernel::MaxOccupancyGridSize(s);
+        }
+        else
+        {
+            grids = Kernel::GridSize(args.M, args.N, args.k_batch);
+        }
+        constexpr dim3 blocks = Kernel::BlockSize();
+
+        if(!Kernel::IsSupportedArgument(kargs))
+        {
+            throw ArgumentsNotSupportedException("Wrong! Arguments not supported! Skipping gemm!\n");
+        }
+
+        if(s.log_level_ > 0)
+        {
+            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
+                      << "shape: " << GemmShape::GetName() << '\n'
+                      << "problem: " << GemmPipelineProblem::GetName() << '\n'
+                      << "pipeline: " << GemmPipeline::GetName() << '\n'
+                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                      << std::endl;
+        }
+        if(s.flush_cache_)
+        {
+            std::cout << "Flushing cache..." << std::endl;
+            static constexpr ck_tile::index_t APackedSize =
+                std::is_same_v<ADataType, ck_tile::pk_int4_t> ? 2 : 1;
+            static constexpr ck_tile::index_t BPackedSize =
+                std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
+
+            ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
+                args.M, args.K, args.stride_A, is_row_major(ALayout{})));
+            ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
+                args.K, args.N, args.stride_B, is_row_major(BLayout{})));
+
+            auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
+            auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
+
+            ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
+                kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
+            rotating_mem.Print();
+
+            auto run_flush_cache = [&]() {
+                // flush icache
+                ck_tile::flush_icache();
+                // rotating mem
+                rotating_mem.Next();
+                // clear c mem
+                if(args.k_batch > 1)
+                    hipGetErrorString(hipMemsetAsync(
+                        args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
+            };
+            ave_time = ck_tile::launch_kernel_preprocess(
+                s,
+                run_flush_cache,
+                ck_tile::make_kernel(
+                    Kernel{}, grids, blocks, 0, kargs));
+        }
+        else
+        {
+            ave_time =
+                ck_tile::launch_kernel(s,
+                                       ck_tile::make_kernel(
+                                           Kernel{}, grids, blocks, 0, kargs));
+        }
+        return ave_time;
+    };
+
+    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
+        if(args.k_batch == 1)
+        {
+            Run(has_hot_loop_,
+                tail_number_,
+                ck_tile::integral_constant<ck_tile::memory_operation_enum,
+                                           ck_tile::memory_operation_enum::set>{});
+        }
+        else
+        {
+            Run(has_hot_loop_,
+                tail_number_,
+                ck_tile::integral_constant<ck_tile::memory_operation_enum,
+                                           ck_tile::memory_operation_enum::atomic_add>{});
+        }
+    };
+
+    BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
+    return ave_time;
+}
+
+template <typename APrecType, typename BPrecType = APrecType>
+bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
+{
+    using Row = ck_tile::tensor_layout::gemm::RowMajor;
+    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+
+    if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
+    {
+        if(a_layout == "R" && b_layout == "C")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Row{}, Col{}, Row{});
+        }
+        else if(a_layout == "C" && b_layout == "C")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Col{}, Col{}, Row{});
+        }
+        else
+        {
+            throw std::runtime_error("Unsupported memory layout for the input matrices when "
+                                     "BPrecType is ck_tile::pk_int4_t!");
+        }
+    }
+    else
+    {
+        if(a_layout == "R" && b_layout == "R")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Row{}, Row{}, Row{});
+        }
+        else if(a_layout == "R" && b_layout == "C")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Row{}, Col{}, Row{});
+        }
+        else if(a_layout == "C" && b_layout == "R")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Col{}, Row{}, Row{});
+        }
+        else if(a_layout == "C" && b_layout == "C")
+        {
+            return run_gemm_test_with_layouts(
+                argc, argv, Col{}, Col{}, Row{});
+        }
+        else
+        {
+            throw std::runtime_error("Unsupported memory layout for the input matrices!");
+        }
+    }
+}
+
+template