From ec55fdf2b3db8d86d2f42a4726840b8b6f269cbc Mon Sep 17 00:00:00 2001 From: Astha Rai Date: Mon, 22 Sep 2025 18:49:21 +0000 Subject: [PATCH] Adding in bf8 streamk example in CK Tile --- .github/CODEOWNERS | 2 +- .github/scripts/therock_configure_ci.py | 24 +- .github/workflows/therock-ci-linux.yml | 51 +- .github/workflows/therock-ci.yml | 9 +- .github/workflows/therock-test-component.yml | 71 ++ .github/workflows/therock-test-packages.yml | 40 +- CHANGELOG.md | 6 + CMakeLists.txt | 9 + Dockerfile | 29 +- Dockerfile.compiler | 2 +- Jenkinsfile | 72 +- codegen/CMakeLists.txt | 3 +- example/01_gemm/gemm_xdl_fp16.cpp | 4 +- example/01_gemm/gemm_xdl_fp16_v2.cpp | 8 +- example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp | 6 +- example/01_gemm/gemm_xdl_wavelet_fp16.cpp | 4 +- example/01_gemm/run_gemm_example.inc | 16 +- .../gemm_bias_relu_xdl_fp16.cpp | 5 +- .../gemm_add_add_fastgelu_xdl_bf16.cpp | 4 +- .../gemm_add_add_fastgelu_xdl_fp16.cpp | 4 +- .../gemm_add_add_fastgelu_xdl_fp32.cpp | 4 +- .../run_gemm_add_add_fastgelu_example.inc | 28 +- .../common.hpp | 2 +- .../run_convnd_fwd_max_example.inc | 2 +- .../12_reduce/reduce_blockwise_two_call.cpp | 23 +- example/13_pool2d_fwd/pool2d_fwd_common.hpp | 4 +- .../gemm_dl_quantization_int8.cpp | 6 +- .../15_grouped_gemm/grouped_gemm_xdl_bf16.cpp | 4 +- .../15_grouped_gemm/grouped_gemm_xdl_fp16.cpp | 4 +- .../15_grouped_gemm/grouped_gemm_xdl_fp32.cpp | 4 +- .../batched_gemm_reduce_xdl_fp16.cpp | 6 +- ...as_relu_add_layernorm_xdl_welford_fp16.cpp | 6 +- example/22_cgemm/cgemm_xdl_bf16.cpp | 12 +- example/22_cgemm/cgemm_xdl_fp16.cpp | 12 +- .../run_batched_gemm_example.inc | 6 +- ..._batched_gemm_example_fp16int4_b_scale.inc | 6 +- .../run_batched_gemm_example_rowwise.inc | 6 +- .../gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp | 16 +- .../gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp | 16 +- .../run_contraction_bilinear_example.inc | 14 +- .../run_contraction_scale_example.inc | 12 +- .../grouped_gemm_bias_e_permute_xdl_fp16.cpp | 17 +- .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 16 +- .../batched_gemm_bias_e_permute_xdl_fp16.cpp | 20 +- .../30_grouped_conv_fwd_multiple_d/common.hpp | 36 +- .../common_wmma.hpp | 36 +- ...atched_gemm_gemm_wmma_cshuffle_v3_base.inc | 4 + .../run_batched_gemm_gemm_example.inc | 6 +- ...run_batched_gemm_gemm_wmma_cshuffle_v3.inc | 12 +- ...le_scale_softmax_gemm_permute_xdl_fp16.cpp | 14 +- .../run_batched_gemm_scale_softmax_gemm.inc | 6 +- ...atched_gemm_scale_softmax_gemm_permute.inc | 12 +- ...d_gemm_scale_softmax_gemm_permute_wmma.inc | 33 +- .../run_cross_attention_wmma.inc | 35 +- ...rouped_gemm_scale_softmax_gemm_permute.inc | 32 +- ...n_grouped_query_attention_forward_wmma.inc | 33 +- ...run_multi_query_attention_forward_wmma.inc | 33 +- .../run_self_attention_wmma.inc | 35 +- .../gemm_xdl_splitk_reduce_bf16.cpp | 8 +- .../gemm_xdl_splitk_reduce_bf16A_i8B.cpp | 8 +- .../gemm_xdl_splitk_reduce_multi_d_bf16.cpp | 8 +- .../gemm_xdl_splitk_reduce_multi_d_fp16.cpp | 8 +- .../sparse_embedding3_forward_layernorm.cpp | 253 ++--- ...ed_gemm_add_add_relu_gemm_add_xdl_fp16.cpp | 6 +- ...rouped_conv_bwd_data_bias_relu_example.inc | 3 +- ...d_bias_perchannel_quantization_example.inc | 3 +- ...fwd_bias_perlayer_quantization_example.inc | 3 +- ...2d_fwd_perchannel_quantization_example.inc | 3 +- .../splitk_gemm_bias_e_permute_xdl_fp16.cpp | 21 +- .../splitk_gemm_bias_e_permute_xdl_fp32.cpp | 25 +- .../elementwise_binary_4D_fp16.cpp | 28 +- .../elementwise_permute_4D_fp16.cpp | 8 +- .../elementwise_permute_4D_fp16_col.cpp | 8 +- .../elementwise_permute_4D_fp16_row.cpp | 9 +- .../elementwise_permute_4D_fp32_col.cpp | 8 +- .../elementwise_permute_4D_fp32_row.cpp | 9 +- .../elementwise_trinary_4D_fp16.cpp | 13 +- .../elementwise_layernorm_blockwise.cpp | 21 +- .../gemm_add_multiply_xdl_fp16.cpp | 4 +- .../run_gemm_add_multiply_example.inc | 43 +- .../gemm_bias_softmax_gemm_permute_xdl.cpp | 16 +- example/48_pool3d_fwd/pool3d_fwd_common.hpp | 9 +- .../49_maxpool2d_bwd/maxpool2d_bwd_common.hpp | 4 +- .../51_avgpool3d_bwd/avgpool3d_bwd_common.hpp | 9 +- .../54_groupnorm_bwd/groupnorm_bwd_fp32.cpp | 21 +- example/60_gemm_multi_ABD/CMakeLists.txt | 4 + ...m_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp | 307 ++++++ .../gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp | 299 ++++++ .../gemm_multi_ABD_wmma_fp16.cpp | 362 +++++++ ...BD_wmma_multiply_bias_fastgelu_bf16_i8.cpp | 296 ++++++ ...mm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp | 45 +- .../gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp | 45 +- ...ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp | 47 +- .../contraction_multi_ABD_xdl_fp16.cpp | 20 +- .../contraction_multi_ABD_xdl_fp8.cpp | 21 +- ...aleadd_scaleadd_relu_bcasted_bias_fp16.cpp | 11 +- example/64_fpAintB_gemm/run_gemm_example.inc | 3 +- ...multiply_multiply_xdl_fp16_bpreshuffle.cpp | 26 +- .../gemm_multiply_multiply_xdl_fp8.cpp | 24 +- ..._multiply_multiply_xdl_fp8_bpreshuffle.cpp | 26 +- .../moe_gemm1_xdl_fp8.cpp | 13 +- .../moe_gemm1_xdl_fp8_blockscale.cpp | 18 +- .../moe_gemm1_xdl_pk_i4.cpp | 13 +- .../moe_gemm2_xdl_fp8.cpp | 13 +- .../moe_gemm2_xdl_fp8_blockscale.cpp | 12 +- .../moe_gemm2_xdl_pk_i4.cpp | 13 +- .../common_instances.hpp | 4 +- ...n_complex_contraction_bilinear_example.inc | 37 +- .../moe_gemm1_xdl_mx_fp4.cpp | 15 +- .../moe_gemm1_xdl_mx_fp4_bns.cpp | 15 +- .../moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp | 18 +- .../moe_gemm2_xdl_mx_fp4.cpp | 13 +- .../moe_gemm2_xdl_mx_fp4_bns.cpp | 15 +- .../moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp | 15 +- example/ck_tile/01_fmha/CMakeLists.txt | 6 +- example/ck_tile/01_fmha/README.md | 27 +- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 7 +- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 7 + .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 97 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 128 ++- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 10 + .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 15 +- .../codegen/ops/fmha_pagedkv_prefill.py | 16 +- example/ck_tile/01_fmha/example_fmha_bwd.cpp | 8 +- example/ck_tile/01_fmha/example_fmha_fwd.cpp | 58 +- .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 148 ++- example/ck_tile/01_fmha/fmha_bwd.hpp | 28 +- example/ck_tile/01_fmha/fmha_bwd_runner.hpp | 22 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 75 +- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 405 ++++++-- example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 5 + example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 4 +- .../ck_tile/01_fmha/script/benchmark_fwd.sh | 33 + .../01_fmha/script/benchmark_fwd_v3.sh | 17 + .../script/fmha_bwd_known_fails_gfx90a.txt | 2 - .../script/fmha_bwd_known_fails_gfx942.txt | 2 - .../script/fmha_bwd_known_fails_gfx950.txt | 31 - .../script/fmha_fwd_known_fails_gfx950.txt | 4 - .../ck_tile/01_fmha/script/run_full_test.sh | 8 +- .../ck_tile/01_fmha/script/smoke_test_bwd.sh | 40 +- .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 138 ++- .../03_gemm/gemm_splitk_two_stage_invoker.hpp | 4 +- example/ck_tile/03_gemm/gemm_utils.hpp | 1 + .../gemm_weight_preshuffle_invoker.hpp | 2 +- .../03_gemm/universal_gemm_invoker.hpp | 4 +- .../ck_tile/17_grouped_gemm/CMakeLists.txt | 4 +- example/ck_tile/17_grouped_gemm/README.md | 167 +-- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 5 +- .../17_grouped_gemm/grouped_gemm_multi_d.cpp | 286 +++++ .../17_grouped_gemm/grouped_gemm_multi_d.hpp | 223 ++++ .../17_grouped_gemm/quant_grouped_gemm.cpp | 20 +- .../quant_run_grouped_gemm_example.inc | 12 + .../run_grouped_gemm_example.inc | 34 +- .../run_grouped_gemm_multi_d_example.inc | 417 ++++++++ .../20_grouped_convolution/CMakeLists.txt | 3 + .../grouped_convolution_backward_data.cpp | 190 +--- ...uped_convolution_backward_data_invoker.hpp | 144 +++ .../grouped_convolution_backward_weight.cpp | 217 +--- ...ed_convolution_backward_weight_invoker.hpp | 145 +++ ..._convolution_backward_weight_two_stage.cpp | 67 ++ ...tion_backward_weight_two_stage_invoker.hpp | 215 ++++ .../grouped_convolution_forward.cpp | 177 +--- .../grouped_convolution_forward_invoker.hpp | 135 +++ ...n_grouped_convolution_bwd_data_example.inc | 77 +- ...grouped_convolution_bwd_weight_example.inc | 124 ++- .../run_grouped_convolution_fwd_example.inc | 77 +- .../21_elementwise/elementwise_example.cpp | 4 +- .../elementwise_example_add_4d.cpp | 4 +- .../elementwise_example_transpose.cpp | 4 +- .../elementwise_example_unary.cpp | 4 +- example/ck_tile/38_block_scale_gemm/README.md | 5 +- .../38_block_scale_gemm/gemm_quant_basic.cpp | 127 ++- .../38_block_scale_gemm/gemm_utils.hpp | 28 +- .../run_gemm_quant_example.inc | 120 ++- .../ck_tile/40_streamk_gemm/CMakeLists.txt | 20 +- example/ck_tile/40_streamk_gemm/README.md | 37 + .../ck_tile/40_streamk_gemm/gemm_utils.hpp | 247 +---- .../40_streamk_gemm/run_gemm_example.inc | 395 +++---- .../ck_tile/40_streamk_gemm/streamk_gemm.cpp | 7 + .../40_streamk_gemm/streamk_gemm_basic.cpp | 205 ++++ include/ck/ck.hpp | 1 + include/ck/host_utility/flush_cache.hpp | 149 ++- ...volution_host_tensor_descriptor_helper.hpp | 19 +- include/ck/library/utility/host_tensor.hpp | 509 ++++++++- .../ck/library/utility/validation_common.hpp | 50 - .../multi_index_transform.hpp | 194 +++- .../multi_index_transform_helper.hpp | 55 +- .../blockwise_gemm_pipeline_xdlops_base.hpp | 7 +- .../blockwise_gemm_pipeline_xdlops_v1.hpp | 56 +- ...kwise_gemm_pipeline_xdlops_v1_ab_scale.hpp | 60 +- ...ckwise_gemm_pipeline_xdlops_v1_b_scale.hpp | 28 +- .../blockwise_gemm_pipeline_xdlops_v2.hpp | 96 +- ...kwise_gemm_pipeline_xdlops_v2_ab_scale.hpp | 46 +- ...ckwise_gemm_pipeline_xdlops_v2_b_scale.hpp | 98 +- .../blockwise_gemm_pipeline_xdlops_v3.hpp | 34 +- ...kwise_gemm_pipeline_xdlops_v3_ab_scale.hpp | 34 +- ...ckwise_gemm_pipeline_xdlops_v3_b_scale.hpp | 36 +- .../blockwise_gemm_pipeline_xdlops_v4.hpp | 46 +- ...ckwise_gemm_pipeline_xdlops_v4_b_scale.hpp | 48 +- .../blockwise_gemm_pipeline_xdlops_v5.hpp | 46 +- .../gpu/device/device_gemm_multiple_abd.hpp | 151 ++- .../device/device_grouped_conv_bwd_weight.hpp | 2 + ...ice_grouped_conv_bwd_weight_multiple_d.hpp | 2 + .../device_batched_gemm_wmma_cshuffle_v3.hpp | 69 +- ...ice_gemm_multiple_abd_wmma_cshuffle_v3.hpp | 422 ++++++++ ...gemm_multiple_d_layernorm_xdl_cshuffle.hpp | 47 +- ...evice_gemm_multiple_d_wmma_cshuffle_v3.hpp | 24 +- .../impl/device_gemm_wmma_cshuffle_v3.hpp | 24 +- .../device_gemm_wmma_cshuffle_v3_b_scale.hpp | 25 +- .../device_gemm_wmma_cshuffle_v3_common.hpp | 42 +- .../impl/device_gemm_wmma_cshuffle_v3r1.hpp | 28 +- .../impl/device_gemm_xdl_skip_b_lds.hpp | 20 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 16 +- ...e_grouped_conv_bwd_weight_explicit_xdl.hpp | 47 +- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 8 + ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 8 + ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 9 + ...fwd_multiple_d_multiple_r_xdl_cshuffle.hpp | 12 +- .../gpu/device/tensor_layout.hpp | 165 +-- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 40 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 9 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 155 ++- ...gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp | 159 ++- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 386 +++++-- .../gridwise_gemm_xdlops_skip_b_lds_v1.hpp | 45 +- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 45 +- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 18 + .../transform_conv_bwd_data_to_gemm_v1.hpp | 60 +- include/ck/utility/amd_ck_fp8.hpp | 39 +- include/ck/utility/amd_xdlops.hpp | 4 +- include/ck/utility/data_type.hpp | 12 +- include/ck/utility/debug.hpp | 4 +- include/ck/utility/dtype_vector.hpp | 16 + include/ck/utility/e8m0.hpp | 2 + include/ck/utility/f8_utils.hpp | 37 +- include/ck/utility/magic_division.hpp | 4 - include/ck/utility/numeric_limits.hpp | 3 +- include/ck/utility/numeric_utils.hpp | 2 + include/ck/utility/random_gen.hpp | 6 +- include/ck/utility/type_convert.hpp | 16 +- .../core/arch/amd_buffer_addressing.hpp | 77 +- .../arch/amd_buffer_addressing_builtins.hpp | 70 +- include/ck_tile/core/arch/arch.hpp | 20 +- include/ck_tile/core/numeric/bfloat16.hpp | 34 +- include/ck_tile/core/tensor/buffer_view.hpp | 7 +- include/ck_tile/core/tensor/load_tile.hpp | 2 +- include/ck_tile/core/tensor/tensor_view.hpp | 2 +- .../ck_tile/core/tensor/tile_elementwise.hpp | 2 +- include/ck_tile/core/tensor/tile_window.hpp | 2 +- include/ck_tile/core/utility/philox_rand.hpp | 16 +- .../core/utility/transpose_vectors.hpp | 10 +- .../reference_batched_dropout_randval.hpp | 12 +- .../ck_tile/host/reference/reference_gemm.hpp | 56 +- ...norm2d_rdquant_fwd_pipeline_three_pass.hpp | 8 +- .../kernel/batched_transpose_kernel.hpp | 6 +- .../ops/common/load_interleaved_pk_type.hpp | 58 ++ .../elementwise/kernel/elementwise_kernel.hpp | 20 +- .../unary_element_wise_operation.hpp | 73 +- .../ops/epilogue/cshuffle_epilogue.hpp | 125 ++- .../ops/flatmm/kernel/flatmm_kernel.hpp | 4 +- .../ck_tile/ops/fmha/block/block_dropout.hpp | 713 ++++++------- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 4 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 41 +- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 4 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 242 +++-- .../fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 4 +- .../fmha_fwd_splitkv_combine_kernel.hpp | 4 +- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 4 +- .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 60 +- ...k_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 16 +- ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 32 +- ...ck_fmha_bwd_dq_dk_dv_pipeline_selector.hpp | 5 +- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 22 +- ...wd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 28 +- ...block_fmha_bwd_pipeline_default_policy.hpp | 14 +- .../block_fmha_bwd_pipeline_problem.hpp | 12 +- ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 4 +- ...gedkv_pipeline_qr_ks_vs_default_policy.hpp | 25 +- ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 4 +- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 6 +- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 1 + ..._ks_vs_whole_k_prefetch_default_policy.hpp | 25 +- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 125 ++- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 24 +- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 17 + .../fused_moe/kernel/fused_moegemm_kernel.hpp | 8 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 16 +- .../fused_moegemm_pipeline_flatmm_uk.hpp | 12 +- include/ck_tile/ops/gemm.hpp | 2 + .../block/block_universal_gemm_as_bs_cr.hpp | 37 +- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 20 +- .../ops/gemm/kernel/gemm_multi_abd_kernel.hpp | 12 + .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 8 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 135 ++- .../ops/gemm/kernel/streamk_gemm_kernel.hpp | 6 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 46 +- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 531 ++++++++++ ...ine_ag_bg_cr_comp_async_default_policy.hpp | 101 ++ .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 17 +- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 2 +- .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 4 +- .../gemm/pipeline/gemm_pipeline_problem.hpp | 4 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 108 +- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 18 +- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 28 +- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 28 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 18 + .../gemm/warp/warp_gemm_attribute_mfma.hpp | 1 + .../warp/warp_gemm_attribute_mfma_impl.hpp | 131 ++- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 5 + include/ck_tile/ops/gemm_group_quant.hpp | 21 - include/ck_tile/ops/gemm_quant.hpp | 24 + ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 191 ++++ .../block_universal_gemm_as_aquant_bs_cr.hpp | 36 +- .../block_universal_gemm_as_bs_bquant_cr.hpp | 31 +- .../kernel/gemm_quant_kernel.hpp | 267 ++++- .../kernel/grouped_gemm_quant_kernel.hpp | 6 +- .../gemm_aquant_pipeline_ag_bg_cr_base.hpp | 0 .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 3 - .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 2 +- .../gemm_bquant_pipeline_ag_bg_cr_base.hpp | 0 .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 3 - .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 2 +- .../pipeline/gemm_group_quant_utils.hpp | 0 .../pipeline/gemm_quant_pipeline_problem.hpp | 33 +- ...p_bquant_pipeline_ag_bg_cr_base_policy.hpp | 60 ++ .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 471 +++++++++ .../pipeline/tile_gemm_quant_traits.hpp | 17 +- ...ouped_convolution_backward_data_kernel.hpp | 52 +- ...ped_convolution_backward_weight_kernel.hpp | 86 +- .../grouped_convolution_forward_kernel.hpp | 39 +- .../utils/grouped_convolution_utils.hpp | 45 +- .../utils/transform_conv_bwd_data_to_gemm.hpp | 41 +- .../transform_conv_bwd_weight_to_gemm.hpp | 38 +- .../utils/transform_conv_fwd_to_gemm.hpp | 147 ++- .../kernel/image_to_column_kernel.hpp | 6 +- .../layernorm2d_fwd_pipeline_two_pass.hpp | 6 +- .../ops/reduce/kernel/reduce2d_kernel.hpp | 4 +- .../rmsnorm2d_fwd_pipeline_two_pass.hpp | 6 +- .../kernel/moe_smoothquant_kernel.hpp | 2 +- .../smoothquant_pipeline_two_pass.hpp | 6 +- .../kernel/topk_softmax_kernel.hpp | 6 +- include/ck_tile/utility/json_dump.hpp | 18 +- .../cpu/reference_moe_gemm.hpp | 8 +- .../cpu/reference_moe_gemm2.hpp | 7 +- .../device_operation_instance_factory.hpp | 3 + .../gpu/gemm_multi_abd.hpp | 954 ++++++++++++++++- ...ice_grouped_conv_bwd_data_xdl_instance.hpp | 118 +++ ...grouped_conv_fwd_xdl_bilinear_instance.hpp | 35 + ...ice_grouped_conv_fwd_xdl_comp_instance.hpp | 22 + .../device_grouped_conv_fwd_xdl_instance.hpp | 19 + ...ped_conv_fwd_xdl_large_tensor_instance.hpp | 22 + ...vice_grouped_conv_fwd_xdl_mem_instance.hpp | 38 +- ...ed_conv_fwd_xdl_merged_groups_instance.hpp | 22 + ...ce_grouped_conv_fwd_xdl_scale_instance.hpp | 35 + ...uped_conv_fwd_xdl_scaleadd_ab_instance.hpp | 24 +- .../gpu/grouped_convolution_backward_data.hpp | 12 + .../grouped_convolution_backward_data_xdl.inc | 84 ++ .../gpu/grouped_convolution_forward.hpp | 51 +- ...grouped_convolution_forward_bias_clamp.hpp | 110 +- ...ped_convolution_forward_bias_clamp_xdl.inc | 176 ++++ .../grouped_convolution_forward_bilinear.hpp | 28 +- .../gpu/grouped_convolution_forward_clamp.hpp | 109 +- .../grouped_convolution_forward_clamp_xdl.inc | 176 ++++ .../grouped_convolution_forward_comp_xdl.inc | 31 + ...uped_convolution_forward_mem_inter_xdl.inc | 31 + ...uped_convolution_forward_mem_intra_xdl.inc | 30 + .../gpu/grouped_convolution_forward_scale.hpp | 27 +- .../gpu/grouped_convolution_forward_xdl.inc | 16 + ...d_convolution_forward_xdl_large_tensor.inc | 32 + ..._convolution_forward_xdl_merged_groups.inc | 32 + .../gpu/gemm_multi_abd/CMakeLists.txt | 32 +- ...multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp | 109 ++ ..._abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++ ...multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp | 85 ++ ...bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++ ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++ ...gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp | 111 ++ ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 59 ++ ...iply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++ ...bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++ ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++ ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++ .../grouped_conv2d_bwd_data/CMakeLists.txt | 3 + ...xc_nhwgk_bf16_optimized_loads_instance.cpp | 49 + ...yxc_nhwgk_f16_optimized_loads_instance.cpp | 49 + ...yxc_nhwgk_f32_optimized_loads_instance.cpp | 49 + .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 6 + ...wgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp | 68 ++ ...dl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp | 66 ++ ...or_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp | 41 + ...kyxc_nhwgk_f32_tf32_mem_inter_instance.cpp | 70 ++ ...kyxc_nhwgk_f32_tf32_mem_intra_instance.cpp | 70 ++ ...ps_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp | 50 + .../CMakeLists.txt | 50 +- ...hwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in | 82 ++ ...sor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in | 54 + ...gkyxc_nhwgk_f32_tf32_mem_inter_instance.in | 85 ++ ...gkyxc_nhwgk_f32_tf32_mem_intra_instance.in | 85 ++ ...ups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in | 68 ++ .../CMakeLists.txt | 6 + ...gc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp | 65 ++ ...l_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 62 ++ ...r_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 43 + ...yxc_nhwgk_fp32_tf32_mem_inter_instance.cpp | 67 ++ ...yxc_nhwgk_fp32_tf32_mem_intra_instance.cpp | 67 ++ ...s_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 56 + .../grouped_conv2d_fwd_clamp/CMakeLists.txt | 6 + ...gc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp | 65 ++ ...l_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 62 ++ ...r_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 43 + ...yxc_nhwgk_fp32_tf32_mem_inter_instance.cpp | 67 ++ ...yxc_nhwgk_fp32_tf32_mem_intra_instance.cpp | 67 ++ ...s_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 55 + .../grouped_conv3d_bwd_data/CMakeLists.txt | 3 + ...c_ndhwgk_bf16_optimized_loads_instance.cpp | 49 + ...xc_ndhwgk_f16_optimized_loads_instance.cpp | 49 + ...xc_ndhwgk_f32_optimized_loads_instance.cpp | 49 + .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 9 +- ...c_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp | 57 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 41 + ...yxc_ndhwgk_f32_tf32_mem_inter_instance.cpp | 59 ++ ...yxc_ndhwgk_f32_tf32_mem_intra_instance.cpp | 59 ++ ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 49 + .../CMakeLists.txt | 47 + ...gc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in | 82 ++ ..._ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in | 54 + ...zyxc_ndhwgk_f32_tf32_mem_inter_instance.in | 85 ++ ...zyxc_ndhwgk_f32_tf32_mem_intra_instance.in | 85 ++ ..._ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in | 68 ++ .../CMakeLists.txt | 5 + ..._gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp | 64 ++ ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 43 + ...xc_ndhwgk_fp32_tf32_mem_inter_instance.cpp | 65 ++ ...xc_ndhwgk_fp32_tf32_mem_intra_instance.cpp | 65 ++ ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 53 + .../CMakeLists.txt | 1 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 57 + .../grouped_conv3d_fwd_clamp/CMakeLists.txt | 5 + ..._gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp | 63 ++ ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 43 + ...xc_ndhwgk_fp32_tf32_mem_inter_instance.cpp | 65 ++ ...xc_ndhwgk_fp32_tf32_mem_intra_instance.cpp | 65 ++ ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 53 + .../grouped_conv3d_fwd_scale/CMakeLists.txt | 1 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 57 + library/src/utility/host_tensor.cpp | 23 +- .../profiler/profile_avg_pool2d_bwd_impl.hpp | 18 +- .../profiler/profile_avg_pool3d_bwd_impl.hpp | 17 +- ...le_batched_gemm_add_relu_gemm_add_impl.hpp | 6 +- .../profile_batched_gemm_b_scale_impl.hpp | 6 +- ...ed_gemm_bias_softmax_gemm_permute_impl.hpp | 34 +- .../profile_batched_gemm_gemm_impl.hpp | 6 +- .../profiler/profile_batched_gemm_impl.hpp | 21 +- .../profile_batched_gemm_reduce_impl.hpp | 6 +- ...profile_batched_gemm_softmax_gemm_impl.hpp | 43 +- ...batched_gemm_softmax_gemm_permute_impl.hpp | 31 +- .../profile_batchnorm_backward_impl.hpp | 14 +- .../profile_batchnorm_forward_impl.hpp | 14 +- .../profiler/profile_batchnorm_infer_impl.hpp | 14 +- .../profiler/profile_contraction_impl.hpp | 44 +- .../profiler/profile_conv_bwd_data_impl.hpp | 17 +- .../profiler/profile_conv_fwd_impl.hpp | 17 +- .../profile_conv_tensor_rearrange_impl.hpp | 20 +- .../profile_elementwise_layernorm_impl.hpp | 13 +- .../profiler/profile_gemm_ab_scale_impl.hpp | 5 - ...ofile_gemm_add_relu_add_layernorm_impl.hpp | 32 +- .../profiler/profile_gemm_add_relu_impl.hpp | 29 +- .../profile_gemm_bias_add_reduce_impl.hpp | 8 +- .../profile_gemm_blockscale_wp_impl.hpp | 34 +- .../profiler/profile_gemm_fastgelu_impl.hpp | 29 +- .../include/profiler/profile_gemm_impl.hpp | 8 +- .../profiler/profile_gemm_multi_abd_impl.hpp | 424 ++++++++ .../profile_gemm_multiply_add_impl.hpp | 34 +- .../profile_gemm_quantization_impl.hpp | 4 +- .../profiler/profile_gemm_reduce_impl.hpp | 25 +- .../profiler/profile_gemm_splitk_impl.hpp | 25 +- .../profiler/profile_gemm_streamk_impl.hpp | 8 +- .../profile_gemm_universal_batched_impl.hpp | 6 +- .../profiler/profile_gemm_universal_impl.hpp | 8 +- ...profile_gemm_universal_preshuffle_impl.hpp | 8 +- .../profile_gemm_universal_reduce_impl.hpp | 8 +- .../profile_gemm_universal_streamk_impl.hpp | 8 +- .../profile_grouped_conv_bwd_data_impl.hpp | 40 +- .../profile_grouped_conv_bwd_weight_impl.hpp | 15 +- ...grouped_conv_fwd_bias_bnorm_clamp_impl.hpp | 24 +- ...ofile_grouped_conv_fwd_bias_clamp_impl.hpp | 26 +- .../profile_grouped_conv_fwd_impl.hpp | 18 +- .../profiler/profile_grouped_gemm_impl.hpp | 24 +- .../profile_groupnorm_bwd_data_impl.hpp | 14 +- .../profiler/profile_groupnorm_fwd_impl.hpp | 14 +- .../profile_layernorm_bwd_data_impl.hpp | 14 +- .../profile_layernorm_bwd_gamma_beta_impl.hpp | 14 +- .../profiler/profile_layernorm_fwd_impl.hpp | 14 +- .../profiler/profile_max_pool2d_bwd_impl.hpp | 18 +- .../profiler/profile_max_pool3d_bwd_impl.hpp | 17 +- .../profiler/profile_permute_scale_impl.hpp | 11 +- .../profiler/profile_pool2d_fwd_impl.hpp | 18 +- .../profiler/profile_pool3d_fwd_impl.hpp | 18 +- .../include/profiler/profile_reduce_impl.hpp | 22 +- .../include/profiler/profile_softmax_impl.hpp | 19 +- profiler/src/CMakeLists.txt | 2 + profiler/src/profile_gemm_multi_abd.cpp | 180 ++++ profiler/src/profile_gemm_multiply_add.cpp | 31 +- pyproject.toml | 4 +- python/ck4inductor/__init__.py | 19 + script/convert_miopen_driver_to_profiler.py | 128 +-- script/launch_tests.sh | 4 +- test/CMakeLists.txt | 3 + test/batched_gemm/test_batched_gemm_wmma.cpp | 239 +++-- test/batched_gemm/test_batched_gemm_xdl.cpp | 239 +++-- .../test_batched_gemm_multi_d_dl.cpp | 34 +- ...est_batched_gemm_softmax_gemm_fp16_xdl.cpp | 20 + .../test_batched_gemm_softmax_gemm_util.hpp | 30 +- ...emm_bias_softmax_gemm_permute_bf16_xdl.cpp | 19 + ...emm_bias_softmax_gemm_permute_fp16_xdl.cpp | 19 + ...ed_gemm_bias_softmax_gemm_permute_util.hpp | 24 +- ...hed_gemm_softmax_gemm_permute_bf16_xdl.cpp | 19 + ...hed_gemm_softmax_gemm_permute_fp16_xdl.cpp | 20 + ...batched_gemm_softmax_gemm_permute_util.hpp | 23 +- test/batchnorm/batchnorm_bwd_rank_4.cpp | 78 +- test/batchnorm/batchnorm_fwd_rank_4.cpp | 34 +- test/batchnorm/batchnorm_infer_rank_4.cpp | 67 +- test/ck_tile/CMakeLists.txt | 1 + .../test_batched_gemm_ut_cases.inc | 3 +- test/ck_tile/elementwise/CMakeLists.txt | 5 +- .../elementwise/test_elementwise_1d.cpp | 77 +- .../epilogue/test_cshuffle_epilogue.cpp | 45 +- .../epilogue/test_cshuffle_epilogue_util.hpp | 48 +- test/ck_tile/fmha/CMakeLists.txt | 8 + test/ck_tile/fmha/test_fmha_bwd.inc | 3 + test/ck_tile/fmha/test_fmha_bwd_bf16.cpp | 2 +- test/ck_tile/fmha/test_fmha_bwd_fp16.cpp | 2 +- test/ck_tile/fmha/test_fmha_bwd_fp32.cpp | 20 + test/ck_tile/fmha/test_fmha_fwd.inc | 480 ++++++++- test/ck_tile/fmha/test_fmha_fwd_fp32.cpp | 39 + test/ck_tile/fmha/test_fmha_fwd_fp8.cpp | 13 +- test/ck_tile/gemm/CMakeLists.txt | 6 + .../gemm/test_gemm_pipeline_basic_bf16.cpp | 9 +- .../gemm/test_gemm_pipeline_basic_bf8.cpp | 10 +- .../gemm/test_gemm_pipeline_basic_fp16.cpp | 11 +- .../gemm/test_gemm_pipeline_basic_fp8.cpp | 10 +- .../test_gemm_pipeline_basic_run_test.inc | 4 +- .../gemm/test_gemm_pipeline_comp_async.cpp | 17 + .../gemm/test_gemm_pipeline_kernel_types.hpp | 11 +- .../test_gemm_pipeline_smoke_run_test.inc | 57 +- .../gemm/test_gemm_pipeline_smoke_util.hpp | 27 + .../test_gemm_pipeline_universal_bf16.cpp | 9 +- .../gemm/test_gemm_pipeline_universal_bf8.cpp | 10 +- .../test_gemm_pipeline_universal_fp16.cpp | 9 +- .../gemm/test_gemm_pipeline_universal_fp8.cpp | 10 +- .../test_gemm_pipeline_universal_int8.cpp | 15 +- .../test_gemm_pipeline_universal_pk_int4.cpp | 15 +- .../test_gemm_pipeline_universal_run_test.inc | 2 +- .../gemm/test_gemm_pipeline_ut_cases.inc | 51 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 15 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 11 +- .../test_gemm_aquant_basic_bf8.cpp | 6 - .../test_gemm_aquant_basic_fp8.cpp | 6 - .../test_gemm_aquant_basic_i4bf8.cpp | 6 - .../test_gemm_aquant_basic_i4f32bf8.cpp | 6 - .../test_gemm_aquant_basic_i4f32fp8.cpp | 6 - .../test_gemm_aquant_basic_i4fp8.cpp | 6 - .../test_gemm_aquant_utils.hpp | 243 ----- .../gemm_block_scale/test_gemm_quant_base.hpp | 198 ++++ .../test_gemm_quant_fixtures.hpp | 975 ++++++++++++++++++ .../test_gemm_quant_typed.cpp | 73 ++ .../test_gemm_quant_ut_cases.inc | 33 + .../test_run_gemm_aquant_example.inc | 616 ----------- test/ck_tile/gemm_multi_abd/CMakeLists.txt | 8 +- .../test_gemm_multi_abd_cshuffle.cpp | 8 +- .../test_gemm_multi_abd_default2d.cpp | 10 +- .../test_gemm_multi_abd_ut_cases_cshuffle.inc | 100 -- .../gemm_streamk/test_gemm_streamk_types.hpp | 2 + .../test_gemm_pipeline_kernel_types.hpp | 25 +- .../test_gemm_pipeline_ut_cases.inc | 8 +- .../test_gemm_pipeline_util.hpp | 49 +- .../grouped_gemm/test_grouped_gemm_util.hpp | 22 +- .../grouped_gemm_multi_d/CMakeLists.txt | 9 + .../test_grouped_gemm_multi_d.cpp | 78 ++ .../test_grouped_gemm_multi_d_ut_cases.inc | 91 ++ .../test_grouped_gemm_multi_d_util.hpp | 558 ++++++++++ .../test_grouped_gemm_preshuffle_util.hpp | 18 +- test/contraction/test_contraction_xdl.cpp | 27 +- .../test_conv_tensor_rearrange.cpp | 39 +- .../test_conv_tensor_rearrange_interface.cpp | 8 +- test/convnd_bwd_data/convnd_bwd_data_xdl.cpp | 31 +- test/convnd_fwd/convnd_fwd_xdl.cpp | 30 +- test/data_type/test_bf8_fnuz.cpp | 20 +- test/data_type/test_fp8_fnuz.cpp | 20 +- .../test_elementwise_layernorm_fp16.cpp | 28 +- test/gemm/gemm_bf16.cpp | 2 +- test/gemm/gemm_fp16.cpp | 2 +- test/gemm/gemm_fp32.cpp | 2 +- test/gemm/gemm_fp64.cpp | 2 +- test/gemm/gemm_int8.cpp | 2 +- test/gemm/gemm_standalone_xdl_fp16.cpp | 17 +- test/gemm/gemm_util.hpp | 38 + test/gemm/run_gemm_test.inc | 53 +- test/gemm_multi_abd/CMakeLists.txt | 9 + test/gemm_multi_abd/test_gemm_common.hpp | 73 ++ .../test_gemm_multi_abd_wmma.cpp | 154 +++ .../test_gemm_multi_abd_xdl.cpp | 154 +++ test/gemm_reduce/gemm_reduce_fp16_xdl.cpp | 23 +- test/gemm_split_k/test_gemm_splitk_util.hpp | 12 +- test/gemm_split_k/test_gemm_splitk_xdl.cpp | 20 + .../test_gemm_universal_util.hpp | 13 +- .../test_gemm_universal_wmma_bf16.cpp | 25 +- .../test_gemm_universal_wmma_fp16.cpp | 25 +- .../test_gemm_universal_wmma_fp8.cpp | 19 +- .../test_gemm_universal_xdl_bf16.cpp | 23 +- .../test_gemm_universal_xdl_fp16.cpp | 23 +- .../test_gemm_universal_xdl_fp8.cpp | 27 +- .../test_gemm_universal_streamk_util.hpp | 13 +- .../test_gemm_universal_streamk_xdl_bf16.cpp | 19 + .../test_gemm_universal_streamk_xdl_fp16.cpp | 20 + .../test_gemm_universal_streamk_xdl_fp8.cpp | 19 + .../test_grouped_convnd_bwd_data_wmma.cpp | 45 +- .../test_grouped_convnd_bwd_data_xdl.cpp | 43 +- ...t_grouped_conv_bwd_weight_xdl_bilinear.cpp | 43 +- .../test_grouped_convnd_bwd_weight.cpp | 30 +- ...rouped_convnd_bwd_weight_interface_xdl.cpp | 53 +- .../test_grouped_convnd_fwd.cpp | 51 +- ...st_grouped_convnd_fwd_bias_bnorm_clamp.cpp | 30 +- .../test_grouped_convnd_fwd_bias_clamp.cpp | 31 +- ...uped_convnd_fwd_bias_clamp_large_cases.cpp | 31 +- .../test_grouped_convnd_fwd_clamp.cpp | 53 +- ...grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp | 33 +- .../test_grouped_convnd_fwd_gk_bias_clamp.cpp | 24 +- .../test_grouped_gemm_splitk_xdl.cpp | 19 + .../test_grouped_gemm_ut_cases.inc | 4 +- test/grouped_gemm/test_grouped_gemm_util.hpp | 16 +- .../magic_number_division.cpp | 27 +- .../test_groupnorm_bwd_data_fp32.cpp | 28 +- .../test_layernorm2d_bwd_data_fp32.cpp | 43 +- .../test_layernorm2d_bwd_gamma_beta_fp32.cpp | 29 +- .../test_groupnorm_fwd_fp16.cpp | 43 +- .../test_groupnorm_fwd_fp32.cpp | 43 +- .../test_layernorm2d_fwd_fp16.cpp | 29 +- .../test_layernorm2d_fwd_fp32.cpp | 29 +- .../test_layernorm4d_fwd_fp16.cpp | 29 +- test/pool/test_avg_pool2d_bwd.cpp | 30 +- test/pool/test_avg_pool2d_fwd.cpp | 29 +- test/pool/test_avg_pool3d_bwd.cpp | 29 +- test/pool/test_avg_pool3d_fwd.cpp | 29 +- test/pool/test_max_pool2d_bwd.cpp | 30 +- test/pool/test_max_pool2d_fwd.cpp | 30 +- test/pool/test_max_pool3d_bwd.cpp | 30 +- test/pool/test_max_pool3d_fwd.cpp | 35 +- test/reduce/reduce_no_index.cpp | 30 +- test/reduce/reduce_with_index.cpp | 30 +- test/softmax/test_softmax_ut_cases.inc | 17 + test/softmax/test_softmax_util.hpp | 148 +-- 653 files changed, 25972 insertions(+), 6782 deletions(-) create mode 100644 .github/workflows/therock-test-component.yml create mode 100644 example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp create mode 100644 example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp create mode 100644 example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp create mode 100644 example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp create mode 100644 example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp create mode 100644 example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp create mode 100644 example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc create mode 100644 example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp create mode 100644 example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp create mode 100644 example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage.cpp create mode 100644 example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp create mode 100644 example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp mode change 100644 => 100755 example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp mode change 100644 => 100755 example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc create mode 100644 example/ck_tile/40_streamk_gemm/README.md create mode 100644 example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp delete mode 100644 include/ck/library/utility/validation_common.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp create mode 100644 include/ck_tile/ops/common/load_interleaved_pk_type.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp delete mode 100644 include/ck_tile/ops/gemm_group_quant.hpp create mode 100644 include/ck_tile/ops/gemm_quant.hpp create mode 100755 include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/block/block_universal_gemm_as_aquant_bs_cr.hpp (95%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/block/block_universal_gemm_as_bs_bquant_cr.hpp (93%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/kernel/gemm_quant_kernel.hpp (79%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/kernel/grouped_gemm_quant_kernel.hpp (98%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp (100%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp (98%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp (99%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp (100%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp (98%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp (99%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/pipeline/gemm_group_quant_utils.hpp (100%) rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/pipeline/gemm_quant_pipeline_problem.hpp (86%) create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp rename include/ck_tile/ops/{gemm_group_quant => gemm_quant}/pipeline/tile_gemm_quant_traits.hpp (76%) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_optimized_loads_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_optimized_loads_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_optimized_loads_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_optimized_loads_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_optimized_loads_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_optimized_loads_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 profiler/include/profiler/profile_gemm_multi_abd_impl.hpp create mode 100644 profiler/src/profile_gemm_multi_abd.cpp create mode 100644 test/ck_tile/fmha/test_fmha_bwd_fp32.cpp create mode 100644 test/ck_tile/fmha/test_fmha_fwd_fp32.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_comp_async.cpp delete mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp delete mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp delete mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp delete mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp delete mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp delete mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp delete mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc delete mode 100644 test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc create mode 100644 test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt create mode 100644 test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp create mode 100644 test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_ut_cases.inc create mode 100644 test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp create mode 100644 test/gemm_multi_abd/CMakeLists.txt create mode 100644 test/gemm_multi_abd/test_gemm_common.hpp create mode 100644 test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp create mode 100644 test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1d0f7df3c6..af36f492ba 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,4 +1,4 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @shumway @vidyasagar-amd +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd # Documentation files docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD *.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD diff --git a/.github/scripts/therock_configure_ci.py b/.github/scripts/therock_configure_ci.py index 557afe2d84..cc66fdbfe8 100644 --- a/.github/scripts/therock_configure_ci.py +++ b/.github/scripts/therock_configure_ci.py @@ -42,6 +42,24 @@ def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]: file=sys.stderr, ) return None + +GITHUB_WORKFLOWS_CI_PATTERNS = [ + "therock*", +] + +def is_path_workflow_file_related_to_ci(path: str) -> bool: + return any( + fnmatch.fnmatch(path, ".github/workflows/" + pattern) + for pattern in GITHUB_WORKFLOWS_CI_PATTERNS + ) or any( + fnmatch.fnmatch(path, ".github/scripts/" + pattern) + for pattern in GITHUB_WORKFLOWS_CI_PATTERNS + ) + +def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> bool: + if paths is None: + return False + return any(is_path_workflow_file_related_to_ci(p) for p in paths) # Paths matching any of these patterns are considered to have no influence over # build or test workflows so any related jobs can be skipped if all paths @@ -82,12 +100,16 @@ def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool: ) other_paths = paths_set - github_workflows_paths + related_to_ci = check_for_workflow_file_related_to_ci(github_workflows_paths) contains_other_non_skippable_files = check_for_non_skippable_path(other_paths) print("should_ci_run_given_modified_paths findings:") print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}") - if contains_other_non_skippable_files: + if related_to_ci: + print("Enabling build jobs since a related workflow file was modified") + return True + elif contains_other_non_skippable_files: print("Enabling TheRock CI jobs since a non-skippable path was modified") return True else: diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 7db124d2a1..695fb1d913 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -27,30 +27,35 @@ jobs: TEATIME_FORCE_INTERACTIVE: 0 AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini steps: + - name: "Checking out repository for rocm-libraries" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/rocm-libraries" + - name: Checkout composable_kernel repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: "composable_kernel" - name: Checkout TheRock repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: ec1c2ef4f2636bce7733fd8c95e1dbb6692c8a57 + ref: 409f43ad9d564454bb1b23f8c8aa15d6b9d25200 path: "TheRock" - name: Runner Health Settings run: | - df -h - cmake --version - echo "Installed Python versions:" - ls -d /opt/python - echo "python: $(which python), python3: $(which python3)" - echo "Git version: $(git --version)" - git config --global --add safe.directory $PWD - git config fetch.parallel 10 + ./TheRock/build_tools/health_status.py - name: Fetch sources run: | - ./TheRock/build_tools/fetch_sources.py --jobs 12 + ./TheRock/build_tools/fetch_sources.py --jobs 12 --no-include-rocm-libraries --no-include-ml-frameworks + + - name: Patch rocm-libraries + run: | + git config --global --add safe.directory '*' + git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch - name: Install python deps run: | @@ -92,32 +97,14 @@ jobs: aws-region: us-east-2 role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external - - name: Create Logs index Files and upload logs + - name: Post Build Upload if: always() run: | - python3 TheRock/build_tools/github_actions/create_log_index.py \ - --build-dir=TheRock/build \ - --amdgpu-family=${{ env.AMDGPU_FAMILIES }} - - python3 TheRock/build_tools/github_actions/upload_build_logs_to_s3.py \ - --build-dir=TheRock/build \ - --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} - - - name: Upload artifacts - run: | - python TheRock/build_tools/github_actions/upload_build_artifacts.py \ + python3 TheRock/build_tools/github_actions/post_build_upload.py \ --run-id ${{ github.run_id }} \ --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ - --build-dir TheRock/build - - - name: Add Links to Job Summary - if: always() - run: | - python TheRock/build_tools/github_actions/upload_build_summary.py \ - --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ - --build-dir TheRock/build + --build-dir TheRock/build \ + --upload therock-test-linux: name: "Test" diff --git a/.github/workflows/therock-ci.yml b/.github/workflows/therock-ci.yml index 3232652b6b..40a3b0bec8 100644 --- a/.github/workflows/therock-ci.yml +++ b/.github/workflows/therock-ci.yml @@ -56,7 +56,14 @@ jobs: uses: ./.github/workflows/therock-ci-linux.yml secrets: inherit with: - cmake_options: "-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON -DTHEROCK_ENABLE_MIOPEN=ON -DTHEROCK_ENABLE_ALL=OFF -DTHEROCK_USE_EXTERNAL_CK=ON -DTHEROCK_CK_SOURCE_DIR=../" + cmake_options: >- + -DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON + -DTHEROCK_ENABLE_MIOPEN=ON + -DTHEROCK_ENABLE_ALL=OFF + -DTHEROCK_USE_EXTERNAL_COMPOSABLE_KERNEL=ON + -DTHEROCK_COMPOSABLE_KERNEL_SOURCE_DIR=../composable_kernel + -DTHEROCK_USE_EXTERNAL_ROCM_LIBRARIES=ON + -DTHEROCK_ROCM_LIBRARIES_SOURCE_DIR=../ amdgpu_families: "gfx94X-dcgpu" test_runs_on: "linux-mi325-1gpu-ossci-rocm" diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml new file mode 100644 index 0000000000..674e93c1de --- /dev/null +++ b/.github/workflows/therock-test-component.yml @@ -0,0 +1,71 @@ +name: Test component + +on: + workflow_call: + inputs: + artifact_run_id: + type: string + default: "" + amdgpu_families: + type: string + test_runs_on: + type: string + platform: + type: string + component: + type: string + + +permissions: + contents: read + +jobs: + test_component: + name: 'Test ${{ fromJSON(inputs.component).job_name }} (shard ${{ matrix.shard }} of ${{ fromJSON(inputs.component).total_shards }})' + runs-on: ${{ inputs.test_runs_on }} + container: + image: ${{ inputs.platform == 'linux' && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:4150afe4759d14822f0e3f8930e1124f26e11f68b5c7b91ec9a02b20b1ebbb98' || null }} + options: --ipc host + --group-add video + --device /dev/kfd + --device /dev/dri + --group-add 992 + --env-file /etc/podinfo/gha-gpu-isolation-settings + strategy: + fail-fast: false + matrix: + # The shard array is based on "total_shards" from "fetch_test_configurations.py" + # The test executable will shard based on the array. (ex: [1, 2, 3, 4] = four test shards) + shard: ${{ fromJSON(inputs.component).shard_arr }} + defaults: + run: + shell: bash + env: + VENV_DIR: ${{ github.workspace }}/.venv + ARTIFACT_RUN_ID: "${{ inputs.artifact_run_id != '' && inputs.artifact_run_id || github.run_id }}" + OUTPUT_ARTIFACTS_DIR: "./build" + THEROCK_BIN_DIR: "./build/bin" + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + steps: + - name: Checkout Repository + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + repository: "ROCm/TheRock" + + - name: Run setup test environment workflow + uses: './.github/actions/setup_test_environment' + with: + ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} + VENV_DIR: ${{ env.VENV_DIR }} + FETCH_ARTIFACT_ARGS: ${{ fromJSON(inputs.component).fetch_artifact_args }} + IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }} + + - name: Test + timeout-minutes: ${{ fromJSON(inputs.component).timeout_minutes }} + env: + SHARD_INDEX: ${{ matrix.shard }} + TOTAL_SHARDS: ${{ fromJSON(inputs.component).total_shards }} + run: | + ${{ fromJSON(inputs.component).test_script }} diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index 37ddd399ad..54e068eb3d 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -37,41 +37,17 @@ jobs: test_components: name: 'Test ${{ matrix.components.job_name }}' - runs-on: ${{ inputs.test_runs_on }} - needs: configure_test_matrix + needs: [configure_test_matrix] # skip tests if no test matrix to run if: ${{ needs.configure_test_matrix.outputs.components != '[]' }} strategy: fail-fast: false matrix: components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }} - defaults: - run: - shell: bash - env: - VENV_DIR: ${{ github.workspace }}/.venv - ARTIFACT_RUN_ID: "${{ github.run_id }}" - OUTPUT_ARTIFACTS_DIR: ${{ github.workspace }}/build - THEROCK_BIN_DIR: "./build/bin" - steps: - - name: Checkout Repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: "ROCm/TheRock" - - - name: Run setup test environment workflow - uses: './.github/actions/setup_test_environment' - with: - ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} - AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} - OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} - VENV_DIR: ${{ env.VENV_DIR }} - FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }} - PLATFORM: ${{ inputs.platform }} - IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }} - - - name: Test - timeout-minutes: ${{ matrix.components.timeout_minutes }} - run: | - if [ "${{ inputs.PLATFORM }}" == "linux" ]; then source ${VENV_DIR}/bin/activate ; else . ${VENV_DIR}/Scripts/activate ; fi - ${{ matrix.components.test_script }} + uses: './.github/workflows/therock-test-component.yml' + with: + artifact_run_id: ${{ github.run_id }} + amdgpu_families: ${{ inputs.amdgpu_families }} + test_runs_on: ${{ inputs.test_runs_on }} + platform: ${{ inputs.platform }} + component: ${{ toJSON(matrix.components) }} diff --git a/CHANGELOG.md b/CHANGELOG.md index dafe1b5c87..9aadc3dc54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,8 +5,12 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.0.0 ### Added +* Added a compute async pipeline in the CK TILE universal GEMM on gfx950 +* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. +* Added the new api to load different memory sizes to SGPR. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. +* Added support for grouped_gemm kernels to perform multi_d elementwise operation. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data * Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). @@ -30,6 +34,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added benchmarking support for tile engine GEMM Multi D. * Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands. * Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. +* Added support for f32 to FMHA (fwd/bwd). +* Added tensor-wise quantization for CK_TILE GEMM. ### Optimized diff --git a/CMakeLists.txt b/CMakeLists.txt index 26d91fe6d8..f4d3a83c34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -220,6 +220,9 @@ rocm_check_target_ids(SUPPORTED_GPU_TARGETS message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") +# Cache SUPPORTED_GPU_TARGETS for debug +set(SUPPORTED_GPU_TARGETS "${SUPPORTED_GPU_TARGETS}" CACHE STRING "List of supported GPU targets") + if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") message(STATUS "Enabling XDL instances") add_definitions(-DCK_USE_XDL) @@ -339,6 +342,7 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF) option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF) +option(ENABLE_JSON_DUMP "Whether to enable json dump for examples." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -352,6 +356,11 @@ if(ENABLE_ASM_DUMP) message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") endif() +if (ENABLE_JSON_DUMP) + add_compile_definitions(CK_ENABLE_JSON_DUMP) + message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/Dockerfile b/Dockerfile index 6f5cd0115d..07327442fe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,27 +1,23 @@ + FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=6.4.1 +ARG ROCMVERSION=7.0.1 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn +ENV DEBIAN_FRONTEND=noninteractive # Add rocm repository RUN set -xe && \ - apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \ - curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg + apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl -RUN if [ "$ROCMVERSION" != "6.5" ]; then \ - sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60401-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60401-1_all.deb && \ - wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \ - sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \ - fi - -RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu jammy main universe | tee -a /etc/apt/sources.list" && \ - amdgpu-install -y --usecase=rocm --no-dkms +RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \ + apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \ + apt update && \ + apt install python3-setuptools python3-wheel -y && \ + apt install rocm-dev -y ## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache @@ -45,7 +41,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libelf-dev \ libnuma-dev \ libpthread-stubs0-dev \ - llvm-amdgpu \ mpich \ net-tools \ pkg-config \ @@ -61,17 +56,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- zip \ libzstd-dev \ openssh-server \ - clang-format-12 \ clang-format-18 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ rm -rf amdgpu-install* && \ -# Remove unnecessary rocm components that take a lot of space - apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt - #Install latest ccache -RUN git clone https://github.com/ccache/ccache.git && \ + git clone https://github.com/ccache/ccache.git && \ cd ccache && mkdir build && cd build && cmake .. && make install && \ #Install ninja build tracing tools cd / && \ diff --git a/Dockerfile.compiler b/Dockerfile.compiler index 0306057e45..47bd8294b6 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4.1" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Jenkinsfile b/Jenkinsfile index efe08a7d41..bb904052bd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -53,7 +53,7 @@ def getBaseDockerImageName(){ } else{ def ROCM_numeric = parseVersion("${params.ROCMVERSION}") - if ( ROCM_numeric.major <= 6 && ROCM_numeric.minor < 5 ){ + if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 1 ){ img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" } else{ @@ -476,7 +476,7 @@ def buildHipClangJob(Map conf=[:]){ def retimage (retimage, image) = getDockerImage(conf) - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 20, unit: 'HOURS') { @@ -538,7 +538,7 @@ def Build_CK(Map conf=[:]){ def image def retimage - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -728,7 +728,7 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -836,7 +836,7 @@ def run_aiter_tests(Map conf=[:]){ dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " echo "Docker flags: ${dockerOpts}" - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -852,13 +852,14 @@ def run_aiter_tests(Map conf=[:]){ } withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 2, unit: 'HOURS'){ + timeout(time: 5, unit: 'HOURS'){ try{ sh "rocminfo" sh "python3 --version" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py" @@ -894,7 +895,7 @@ def run_pytorch_tests(Map conf=[:]){ dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " echo "Docker flags: ${dockerOpts}" - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -930,7 +931,8 @@ def run_pytorch_tests(Map conf=[:]){ } //launch develop branch daily jobs -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true + 0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true @@ -957,8 +959,8 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '6.4.1', - description: 'Specify which ROCM version to use: 6.4.1 (default).') + defaultValue: '7.0.1', + description: 'Specify which ROCM version to use: 7.0.1 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', @@ -1037,8 +1039,8 @@ pipeline { description: "Build CK and run tests on gfx942 (default: ON)") booleanParam( name: "BUILD_GFX950", - defaultValue: false, - description: "Build CK and run tests on gfx950 (default: OFF)") + defaultValue: true, + description: "Build CK and run tests on gfx950 (default: ON)") booleanParam( name: "BUILD_GFX10", defaultValue: true, @@ -1125,16 +1127,16 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ - -o -not -path \'*.git*\' -iname \'*.hpp\' \ - -o -not -path \'*.git*\' -iname \'*.cpp\' \ - -o -iname \'*.h.in\' \ - -o -iname \'*.hpp.in\' \ - -o -iname \'*.cpp.in\' \ - -o -iname \'*.cl\' \ + execute_cmd = "(cd .. && git ls-files \'*.h\' \ + \'*.hpp\' \ + \'*.cpp\' \ + \'*.h.in\' \ + \'*.hpp.in\' \ + \'*.cpp.in\' \ + \'*.cl\' \ | grep -v 'build/' \ | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\') && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ @@ -1155,16 +1157,17 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ - -o -not -path \'*.git*\' -iname \'*.hpp\' \ - -o -not -path \'*.git*\' -iname \'*.cpp\' \ - -o -iname \'*.h.in\' \ - -o -iname \'*.hpp.in\' \ - -o -iname \'*.cpp.in\' \ - -o -iname \'*.cl\' \ + execute_cmd = "(cd .. && git ls-files \ + \'*.h\' \ + \'*.hpp\' \ + \'*.cpp\' \ + \'*.h.in\' \ + \'*.hpp.in\' \ + \'*.cpp.in\' \ + \'*.cl\' \ | grep -v 'build/' \ | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'" + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\')" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) @@ -1290,7 +1293,7 @@ pipeline { agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" - execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \ + execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake -DCMAKE_PREFIX_PATH=/opt/rocm ../codegen && \ make -j64 check""" } steps{ @@ -1350,7 +1353,6 @@ pipeline { } agent{ label rocmnode("gfx950") } environment{ - def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0" setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx950 && \ make -j128 tile_example_fmha_fwd tile_example_fmha_bwd && \ @@ -1358,7 +1360,7 @@ pipeline { example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx950 """ } steps{ - buildHipClangJobAndReboot(setup_args:setup_args, docker_name: docker_name, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() } } @@ -1566,7 +1568,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } @@ -1631,7 +1633,7 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ - buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0") + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB}:ck_ub24.04_rocm7.0.1") } cleanWs() } @@ -1657,13 +1659,13 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on gfx1101") + stage("Build CK and run Tests on gfx11") { when { beforeAgent true expression { params.BUILD_GFX11.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx1101") } + agent{ label 'miopen && (gfx1101 || gfx1100)' } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 2b2e6e2949..80429a781b 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -12,6 +12,7 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h) find_package(ROCM) include(ROCMInstallTargets) include(ROCMTest) +find_package(hiprtc REQUIRED) rocm_setup_version(VERSION 1.0) @@ -27,7 +28,7 @@ add_compile_options(-std=c++20) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) -target_link_libraries(ck_host PRIVATE ck_headers) +target_link_libraries(ck_host PRIVATE ck_headers hiprtc::hiprtc) set_target_properties(ck_host PROPERTIES LINKER_LANGUAGE CXX diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 414683ffdf..66a0d98238 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -37,7 +37,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 2, S<1, 16, 1, 16>, 8, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 2, S<1, 16, 1, 16>, 4, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>; // clang-format on using DeviceGemmInstance = DeviceGemmInstance1; diff --git a/example/01_gemm/gemm_xdl_fp16_v2.cpp b/example/01_gemm/gemm_xdl_fp16_v2.cpp index ecd3b7be5d..59c059d014 100644 --- a/example/01_gemm/gemm_xdl_fp16_v2.cpp +++ b/example/01_gemm/gemm_xdl_fp16_v2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -33,13 +33,13 @@ using DeviceGemmInstance = 2, 256, 256, 256, 32, 8, 4, - 32, 32, - 4, 4, + 16, 16, + 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v1>; // clang-format on diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp index d149fd88f1..d5c42558c4 100644 --- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp +++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -36,7 +36,7 @@ using BDataType = ck::half_t; using CDataType = ck::half_t; using AccDataType = float; #else - < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>; + < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 4, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>; using ADataType = float; using BDataType = float; using CDataType = float; @@ -185,7 +185,6 @@ int main(int argc, char* argv[]) auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; - // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); @@ -209,8 +208,7 @@ int main(int argc, char* argv[]) return 0; } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; diff --git a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp index d8672f6a0c..76a30657f0 100644 --- a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp +++ b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -29,7 +29,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_WaveletM // ######| | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, F16, CDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1,8>, 8>; + < ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, F16, CDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1,8>, 4>; // clang-format on using DeviceGemmInstance = DeviceGemmInstance; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 08e2b8c15f..7fb0c1e812 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -2,7 +2,6 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/library/utility/validation_common.hpp" // use macro to minimize code change #ifndef EXAMPLE_WITH_COMPUTE_DATATYPE @@ -29,11 +28,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if constexpr(std::is_same_v) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; @@ -59,17 +58,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); - try - { - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - } - catch(const std::runtime_error& e) - { - std::cerr << "Error: " << e.what() << std::endl; - return false; - } - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp index bffa2e5640..992e7c19c8 100644 --- a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp +++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp @@ -174,6 +174,9 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + const auto StrideD = std::is_same::value + ? d_m_n.mDesc.GetStrides()[0] + : d_m_n.mDesc.GetStrides()[1]; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; @@ -221,7 +224,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{0}, + std::array{static_cast(StrideD)}, StrideE, a_element_op, b_element_op, diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp index e630f67837..4e98bf3034 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,7 +32,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm int { + if constexpr(std::is_same_v) + { + return static_cast(tensor.GetStrides()[0]); + } + else + { + return static_cast(tensor.GetStrides()[1]); + } + }; + + if(StrideA <= 0) + StrideA = fetch_leading_stride(a_m_k, ALayout{}); + if(StrideB <= 0) + StrideB = fetch_leading_stride(b_k_n, BLayout{}); + if(StrideD0 <= 0) + StrideD0 = fetch_leading_stride(d0_m_n, D0Layout{}); + if(StrideD1 <= 0) + StrideD1 = fetch_leading_stride(d1_m_n, D1Layout{}); + if(StrideE <= 0) + StrideE = fetch_leading_stride(e_m_n_host_result, ELayout{}); + switch(config.init_method) { case 0: break; diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp index 036f288d0a..7142521c55 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -125,7 +125,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); problem_size = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index c4e7068499..4b290d02a2 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -23,7 +23,7 @@ using RsGlobalReduceOp = static constexpr auto ConvSpec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off template diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp index eb8b5c76d3..9e125c4e5d 100644 --- a/example/12_reduce/reduce_blockwise_two_call.cpp +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -100,13 +100,13 @@ int main(int argc, char* argv[]) const std::array reduceDims = {3, 4}; // const std::array invariantDims = {0, 1, 2}; - const std::vector inLengths_1 = {64, 320, 80, 4, 128}; + std::vector inLengths_1 = {64, 320, 80, 4, 128}; // input lengths of the second reduction, which is also the output lengths of the first // reduction - const std::vector inLengths_2 = {64, 320, 80, 4}; + std::vector inLengths_2 = {64, 320, 80, 4}; - const std::vector outLengths = {64, 320, 80}; + std::vector outLengths = {64, 320, 80}; if(argc == 1) { @@ -114,11 +114,26 @@ int main(int argc, char* argv[]) init_method = 2; time_kernel = true; } - else if(argc == 4) + else if((argc == 4) || (argc == 9)) { do_verify = static_cast(argv[1]); init_method = atoi(argv[2]); time_kernel = static_cast(atoi(argv[3])); + if(argc == 9) + { + inLengths_1[0] = atoi(argv[4]); + inLengths_1[1] = atoi(argv[5]); + inLengths_1[2] = atoi(argv[6]); + inLengths_1[3] = atoi(argv[7]); + inLengths_1[4] = atoi(argv[8]); + inLengths_2[0] = inLengths_1[0]; + inLengths_2[1] = inLengths_1[1]; + inLengths_2[2] = inLengths_1[2]; + inLengths_2[3] = inLengths_1[3]; + outLengths[0] = inLengths_1[0]; + outLengths[1] = inLengths_1[1]; + outLengths[2] = inLengths_1[2]; + } } else { diff --git a/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/example/13_pool2d_fwd/pool2d_fwd_common.hpp index 3ce08fd2af..abbf1b29f7 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_common.hpp +++ b/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -78,12 +78,12 @@ bool pool_test(bool do_verification, if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout); } }; diff --git a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp index 2585072dfe..5291f5ce69 100644 --- a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp @@ -115,12 +115,14 @@ int main() if(std::is_same::value) { return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1_uz})); + std::vector({stride, 1_uz}), + layout); } else { return HostTensorDescriptor(std::vector({row, col}), - std::vector({1_uz, stride})); + std::vector({1_uz, stride}), + layout); } }; diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp index 680cee1f81..ac64a468a4 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index 90a12bc1dd..85ea8c2f2c 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index 28b0fcd0ce..fb047ae364 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index 13da444051..4a701e7792 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -137,11 +137,13 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {row * stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {row * stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {col * stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {col * stride, 1_uz, stride}, layout); } }; diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index ce9f9b7032..ae5e3f36ad 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -65,7 +65,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern //######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| //######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>; + < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, S<8, 32>, 4>; // clang-format on auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { @@ -154,8 +154,8 @@ void host_gemm_layernorm(Tensor& h_m_n, int main() { - // temp disable on gfx11 & gfx12 - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + // temp disable on gfx11 + if(ck::is_gfx11_supported()) { return 0; } diff --git a/example/22_cgemm/cgemm_xdl_bf16.cpp b/example/22_cgemm/cgemm_xdl_bf16.cpp index fa4482a984..716d36b487 100644 --- a/example/22_cgemm/cgemm_xdl_bf16.cpp +++ b/example/22_cgemm/cgemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -48,10 +48,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 32, // index_t KPerBlock 8, // index_t AK1 8, // index_t BK1 - 32, // index_t MPerXDL - 32, // index_t NPerXDL - 4, // index_t MXdlPerWave - 2, // index_t NXdlPerWave + 16, // index_t MPerXDL + 16, // index_t NPerXDL + 8, // index_t MXdlPerWave + 4, // index_t NXdlPerWave S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder @@ -69,7 +69,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 1, // index_t CShuffleMXdlPerWavePerShuffle 1, // index_t CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on int main(int argc, char* argv[]) diff --git a/example/22_cgemm/cgemm_xdl_fp16.cpp b/example/22_cgemm/cgemm_xdl_fp16.cpp index 89a581e865..2996d87b28 100644 --- a/example/22_cgemm/cgemm_xdl_fp16.cpp +++ b/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -47,10 +47,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 32, // index_t KPerBlock 8, // index_t AK1 8, // index_t BK1 - 32, // index_t MPerXDL - 32, // index_t NPerXDL - 4, // index_t MXdlPerWave - 2, // index_t NXdlPerWave + 16, // index_t MPerXDL + 16, // index_t NPerXDL + 8, // index_t MXdlPerWave + 4, // index_t NXdlPerWave S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder @@ -68,7 +68,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 1, // index_t CShuffleMXdlPerWavePerShuffle 1, // index_t CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on int main(int argc, char* argv[]) diff --git a/example/24_batched_gemm/run_batched_gemm_example.inc b/example/24_batched_gemm/run_batched_gemm_example.inc index 741512bf00..c93a2051d2 100644 --- a/example/24_batched_gemm/run_batched_gemm_example.inc +++ b/example/24_batched_gemm/run_batched_gemm_example.inc @@ -59,11 +59,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(std::is_same::value) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 3582bc5e33..ac34ed5b8a 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -137,11 +137,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto layout) { if constexpr(std::is_same_v) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc index 778be8ffd7..9939429a08 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc @@ -64,11 +64,13 @@ bool run_batched_gemm_rowwise(const ProblemSize& problem_size, const ExecutionCo if(std::is_same::value) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp index 420a7cf74f..4f4003809b 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp @@ -19,6 +19,9 @@ #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -342,7 +345,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_gs_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1 using S = ck::Sequence; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -342,7 +345,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_gs_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1 a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; @@ -189,7 +191,7 @@ int run_contraction_bilinear_example(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; @@ -173,7 +175,7 @@ int run_contraction_scale_example(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 #include @@ -18,6 +18,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -304,10 +307,10 @@ int main(int argc, char* argv[]) const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Bypass{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); ck::index_t M_ = ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); @@ -416,9 +419,9 @@ int main(int argc, char* argv[]) const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data()); diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index f556be887f..c4cb7a13a2 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -17,6 +17,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -300,11 +303,11 @@ int main(int argc, char* argv[]) std::vector e_gs_ms_ns_strides{ G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; @@ -396,7 +399,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1 #include @@ -17,6 +17,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -54,7 +57,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; + DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>; // clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -345,7 +348,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + #include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc" int main(int argc, char* argv[]) diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc index 255a0cddaf..7a03e9cacf 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc @@ -110,11 +110,13 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc index 8ab47c2925..cea18459f4 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc @@ -62,17 +62,19 @@ int run(int argc, char* argv[]) std::vector b1_g_o_n_lengths{G, O, N}; #ifdef CK_MHA_USE_RCCR_LAYOUT std::vector b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N] + auto b1_layout = Row{}; #else std::vector b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O] + auto b1_layout = Col{}; #endif std::vector c_g_m_o_lengths{G, M, O}; std::vector c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O] - Tensor a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides); - Tensor b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides); - Tensor b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides); - Tensor c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides); - Tensor c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides); + Tensor a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides, Row{}); + Tensor b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides, Row{}); + Tensor b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides, b1_layout); + Tensor c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides, Row{}); + Tensor c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides, Row{}); std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp index 5794924294..7738a6b6d4 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -100,11 +100,11 @@ using DeviceGemmInstance = 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 4, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -129,7 +129,7 @@ using DeviceGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec>; // MaskingSpecialization // Ref Gemm0: fp16 in, fp32 out diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc index 1514fc48b3..aa2a6b3b42 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc @@ -111,12 +111,14 @@ int run(int argc, char* argv[]) if(std::is_same::value) { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, stride, 1})); + std::vector({batch_stride, stride, 1}), + layout); } else { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, 1, stride})); + std::vector({batch_stride, 1, stride}), + layout); } }; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc index 2b02069e65..6175f0b5be 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -88,11 +90,11 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Bypass{}); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, Bypass{}); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, Bypass{}); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc index e0ccb6dad1..db13e3b963 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -88,11 +92,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc index 0ad031cc71..1e4b52d4cf 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -113,11 +117,30 @@ int run(int argc, char* argv[]) head_dim, 1}; // C layout [batch_size, head_num, q_sequence_length, head_dim] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; @@ -191,7 +214,7 @@ int run(int argc, char* argv[]) head_num * 2 * head_dim, head_dim, 1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim] - Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides); + Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides, Bypass{}); // merge kv into a packed pointer send to device b0_gs_ns_ks.ForEach( [&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index c693995140..874d987a1d 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -63,6 +67,19 @@ int run(int argc, char* argv[]) std::size_t flop = 0, num_byte = 0; + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; std::cout << "group count " << group_count << ". printing first 4 groups\n"; for(std::size_t i = 0; i < group_count; i++) { @@ -113,10 +130,14 @@ int run(int argc, char* argv[]) {}}); // acc1_biases_gs_ms_os_strides // C_m_o = A_m_k * B0_k_n * B1_n_o - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks(f_host_tensor_descriptor( + b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns(f_host_tensor_descriptor( + b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_device_result(f_host_tensor_descriptor( + c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); int Batch = G0 * G1; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; @@ -252,7 +273,8 @@ int run(int argc, char* argv[]) Tensor acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 Tensor a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax Tensor c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_host_result(f_host_tensor_descriptor( + c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); // permute a_gs_ms_ks.ForEach([&](auto& self, auto idx) { diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc index 7ac29f33ca..1c2a26d916 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -91,11 +95,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc index fb9b1b0bd7..76f3ee756c 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -91,11 +95,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc index 2cb69380e5..86754927ed 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -108,11 +112,30 @@ int run(int argc, char* argv[]) head_dim, 1}; // C layout [batch_size, head_num, sequence_length, head_dim] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; @@ -186,7 +209,7 @@ int run(int argc, char* argv[]) head_num * 3 * head_dim, head_dim, 1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim] - Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides); + Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides, Bypass{}); // merge qkv into a packed pointer send to device a_gs_ms_ks.ForEach( [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp index 7ceb1d09ef..1843198933 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; // clang-format on diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp index b5aeff65d6..1e4398b9f6 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ReduceDataType>; // clang-format on diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp index cb84f2a416..d5acde139a 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ReduceDataType>; // clang-format on diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp index 2ab8f77dc4..bb3c23f060 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v2, ReduceDataType>; // clang-format on diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index 26a03f289d..a1b952259f 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -50,14 +50,14 @@ template<> struct emb_kernel { using kernel_type = DeviceInsta // clang-format on -int main() +int main(int argc, char* argv[]) { bool time_kernel = true; - constexpr auto num_rows = 65536; - constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{}; - // constexpr auto dims = ck::Sequence<256, 512>{}; - constexpr auto index_length = 2048; + ck::index_t num_rows = 65536; + constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{}; + ck::index_t index_length = 2048; + ck::index_t dim_mask = 0xffff; constexpr AccDataType epsilon = 1e-4; auto f_host_tensor_desc_1d = [](std::size_t len_) { return HostTensorDescriptor({len_}); }; @@ -73,121 +73,140 @@ int main() BetaDataType, AccDataType, OutType>; - + if(argc == 1) + { + // Use default value + } + else if(argc == 4) + { + num_rows = atoi(argv[1]); + dim_mask = strtol(argv[2], nullptr, 0); + index_length = atoi(argv[3]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1-3: num_rows dim_mask index_length" << std::endl; + } ck::static_for<0, dims.Size(), 1>{}([&](auto I) { - std::srand(std::time(nullptr)); - constexpr auto current_dim = dims.At(I); - Tensor emb_a(f_host_tensor_desc_2d(num_rows, current_dim)); - Tensor emb_b(f_host_tensor_desc_2d(num_rows, current_dim)); - Tensor emb_c(f_host_tensor_desc_2d(num_rows, current_dim)); - - Tensor index_a(f_host_tensor_desc_1d(index_length)); - Tensor index_b(f_host_tensor_desc_1d(index_length)); - Tensor index_c(f_host_tensor_desc_1d(index_length)); - - Tensor gamma(f_host_tensor_desc_1d(current_dim)); - Tensor beta(f_host_tensor_desc_1d(current_dim)); - - Tensor out(f_host_tensor_desc_2d(index_length, current_dim)); - - emb_a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - emb_b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - emb_c.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - index_a.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); - index_b.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); - index_c.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); - - gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize()); - DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize()); - DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize()); - - DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize()); - DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize()); - DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize()); - - DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); - DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); - - DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize()); - - emb_a_dev.ToDevice(emb_a.mData.data()); - emb_b_dev.ToDevice(emb_b.mData.data()); - emb_c_dev.ToDevice(emb_c.mData.data()); - - index_a_dev.ToDevice(index_a.mData.data()); - index_b_dev.ToDevice(index_b.mData.data()); - index_c_dev.ToDevice(index_c.mData.data()); - - gamma_dev.ToDevice(gamma.mData.data()); - beta_dev.ToDevice(beta.mData.data()); - - auto device_instance = typename emb_kernel::kernel_type{}; - auto argument_ptr = device_instance.MakeArgumentPointer( - out_dev.GetDeviceBuffer(), - {ck::type_convert(emb_a_dev.GetDeviceBuffer()), - ck::type_convert(emb_b_dev.GetDeviceBuffer()), - ck::type_convert(emb_c_dev.GetDeviceBuffer())}, - {ck::type_convert(index_a_dev.GetDeviceBuffer()), - ck::type_convert(index_b_dev.GetDeviceBuffer()), - ck::type_convert(index_c_dev.GetDeviceBuffer())}, - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - current_dim, - index_length, - epsilon, - EmbElementwiseOperation{}); - std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() - << std::endl - << std::flush; - - bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get()); - - if(!is_supported) + if(dim_mask & (1 << I.value)) { - std::cout << "Runtime parameters are not supported" << std::endl; - return; + std::srand(std::time(nullptr)); + constexpr auto current_dim = dims.At(I); + Tensor emb_a(f_host_tensor_desc_2d(num_rows, current_dim)); + Tensor emb_b(f_host_tensor_desc_2d(num_rows, current_dim)); + Tensor emb_c(f_host_tensor_desc_2d(num_rows, current_dim)); + + Tensor index_a(f_host_tensor_desc_1d(index_length)); + Tensor index_b(f_host_tensor_desc_1d(index_length)); + Tensor index_c(f_host_tensor_desc_1d(index_length)); + + Tensor gamma(f_host_tensor_desc_1d(current_dim)); + Tensor beta(f_host_tensor_desc_1d(current_dim)); + + Tensor out(f_host_tensor_desc_2d(index_length, current_dim)); + + emb_a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + emb_b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + emb_c.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + index_a.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + index_b.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + index_c.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize()); + DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize()); + DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize()); + + DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize()); + DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize()); + DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize()); + + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + + DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize()); + + emb_a_dev.ToDevice(emb_a.mData.data()); + emb_b_dev.ToDevice(emb_b.mData.data()); + emb_c_dev.ToDevice(emb_c.mData.data()); + + index_a_dev.ToDevice(index_a.mData.data()); + index_b_dev.ToDevice(index_b.mData.data()); + index_c_dev.ToDevice(index_c.mData.data()); + + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + auto device_instance = typename emb_kernel::kernel_type{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + out_dev.GetDeviceBuffer(), + {ck::type_convert(emb_a_dev.GetDeviceBuffer()), + ck::type_convert(emb_b_dev.GetDeviceBuffer()), + ck::type_convert(emb_c_dev.GetDeviceBuffer())}, + {ck::type_convert(index_a_dev.GetDeviceBuffer()), + ck::type_convert(index_b_dev.GetDeviceBuffer()), + ck::type_convert(index_c_dev.GetDeviceBuffer())}, + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + current_dim, + index_length, + epsilon, + EmbElementwiseOperation{}); + std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() + << std::endl + << std::flush; + + bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get()); + + if(!is_supported) + { + std::cout << "Runtime parameters are not supported" << std::endl; + return; + } + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + float time_ms = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor out_from_dev(f_host_tensor_desc_2d(index_length, current_dim)); + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(out, + emb_a, + emb_b, + emb_c, + index_a, + index_b, + index_c, + gamma, + beta, + num_rows, + current_dim, + index_length, + epsilon); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + out_dev.FromDevice(out_from_dev.mData.data()); + pass &= + ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3); + } + + double total_read = current_dim * index_length * 3 * sizeof(EmbType) + + current_dim * sizeof(GammaDataType) + + current_dim * sizeof(BetaDataType); + double total_write = current_dim * index_length * sizeof(OutType); + double gbps = (total_read + total_write) / time_ms / 1e6; + + std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms + << ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl + << std::flush; } - - auto invoker_ptr = device_instance.MakeInvokerPointer(); - float time_ms = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - bool pass = true; - { - Tensor out_from_dev(f_host_tensor_desc_2d(index_length, current_dim)); - ReferenceInstance ref; - auto ref_argument = ref.MakeArgument(out, - emb_a, - emb_b, - emb_c, - index_a, - index_b, - index_c, - gamma, - beta, - num_rows, - current_dim, - index_length, - epsilon); - auto ref_invoker = ref.MakeInvoker(); - ref_invoker.Run(ref_argument); - - out_dev.FromDevice(out_from_dev.mData.data()); - pass &= ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3); - } - - double total_read = current_dim * index_length * 3 * sizeof(EmbType) + - current_dim * sizeof(GammaDataType) + - current_dim * sizeof(BetaDataType); - double total_write = current_dim * index_length * sizeof(OutType); - double gbps = (total_read + total_write) / time_ms / 1e6; - - std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms - << ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl - << std::flush; }); return 0; diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index 904ff761fd..4934f74393 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -321,11 +321,13 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc index 0f0b120cbc..80d56cd781 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc @@ -206,7 +206,8 @@ int run_grouped_conv_bwd_data_bias_relu_example(int argc, char* argv[]) 1, // c 0, // hi 0 // wi - }); + }, + ctc::GNCHW{}); // input image: GNHWC const auto in_g_n_c_wis_desc = diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc index 30e0791ebf..3c089688cf 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc @@ -214,7 +214,8 @@ int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_ 1, // k 0, // ho 0 // wo - }); + }, + BiasLayout{}); const auto requant_scale_g_k_desc = bias_g_k_desc; diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc index 32fd435e00..ed7886e76b 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc @@ -201,7 +201,8 @@ int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_el 1, // k 0, // ho 0 // wo - }); + }, + BiasLayout{}); const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc index 362d90b4c1..12fdf425bf 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc @@ -203,7 +203,8 @@ int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_eleme 1, // k 0, // ho 0 // wo - }); + }, + RequantScaleLayout{}); const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp index ebba88cf41..b5e9686260 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp @@ -22,6 +22,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Add = ck::tensor_operation::element_wise::Add; @@ -250,19 +253,24 @@ int main(int argc, char* argv[]) Tensor a_gs_ms_ks( std::vector(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), - std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); + std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()), + Row{}); Tensor b_gs_ns_ks( std::vector(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), - std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end())); + std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()), + Row{}); Tensor d_gs_ms_ns( std::vector(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), - std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); + std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_device_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -372,7 +380,8 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1 #include @@ -22,6 +22,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Add = ck::tensor_operation::element_wise::Add; @@ -53,7 +56,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>; + DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 4, 4, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 2>; // clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -250,19 +253,24 @@ int main(int argc, char* argv[]) Tensor a_gs_ms_ks( std::vector(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), - std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); + std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()), + Row{}); Tensor b_gs_ns_ks( std::vector(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), - std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end())); + std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()), + Row{}); Tensor d_gs_ms_ns( std::vector(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), - std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); + std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_device_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -372,7 +380,8 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1 nchw = {16, 128, 32, 64}; + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + nchw[0] = std::stoi(argv[1]); + nchw[1] = std::stoi(argv[2]); + nchw[2] = std::stoi(argv[3]); + nchw[3] = std::stoi(argv[4]); + } + else + { + std::cerr << "arg1 to 4: N, C, H, W" << std::endl; + + return 1; + } + std::array ab_lengths; std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), static_cast(nchw[2] * nchw[3]), @@ -73,11 +93,11 @@ int main(int argc, char* argv[]) 1}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 2> as = {Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides)}; + std::array, 2> as = {Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{})}; Tensor& a0 = as[0]; Tensor& a1 = as[1]; - Tensor b(ab_lengths, ab_strides); + Tensor b(ab_lengths, ab_strides, NchwLayout{}); float alpha = 3.f; float beta = 2.f; a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -134,7 +154,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, ab_strides); + Tensor host_b(ab_lengths, ab_strides, NchwLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 9e92543252..2d689648f2 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -22,6 +22,8 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple @@ -72,9 +74,9 @@ int main(int argc, char* argv[]) static_cast(nhwc[3])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -117,7 +119,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index 88c23b5f40..6e70a306d3 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -23,6 +23,8 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +78,9 @@ int main(int argc, char* argv[]) static_cast(nhwc[0] * nhwc[1])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 1.f; auto i = 0; std::mt19937 gen(11939); @@ -137,7 +139,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index 1185b5a3ca..632d88e88a 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +79,9 @@ int main(int argc, char* argv[]) ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 2.f; a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -128,7 +131,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index 28a3dbc44c..bd54f1c19c 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -22,6 +22,8 @@ using F32 = float; using ADataType = F32; using BDataType = F32; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +78,9 @@ int main(int argc, char* argv[]) static_cast(nhwc[0] * nhwc[1])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 1.f; auto i = 0; @@ -139,7 +141,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index 14d1d96165..9621d591a9 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F32; using BDataType = F32; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +79,9 @@ int main(int argc, char* argv[]) ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 2.f; a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -127,7 +130,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp index 2583f1cb5e..be4014f636 100644 --- a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -78,13 +81,13 @@ int main(int argc, char* argv[]) ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 3> as = {Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides)}; + std::array, 3> as = {Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{})}; Tensor& a0 = as[0]; Tensor& a1 = as[1]; Tensor& a2 = as[2]; - Tensor b(ab_lengths, ab_strides); + Tensor b(ab_lengths, ab_strides, NchwLayout{}); float alpha = 3.f; float beta = 2.f; float gamma = 4.f; @@ -149,7 +152,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, ab_strides); + Tensor host_b(ab_lengths, ab_strides, NchwLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index 51006e676b..8064809123 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -98,8 +98,23 @@ int main(int argc, char* argv[]) exit(0); } - ck::index_t M = 48 * 256; - ck::index_t N = 1024; + ck::index_t M = 48 * 256; + ck::index_t N = 1024; + if(argc == 1) + { + // use default case + } + else if(argc == 3) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + } + else + { + std::cerr << "arg1 to 2: M, N" << std::endl; + return 1; + } + ck::index_t Stride = N; auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { diff --git a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp index 56417b101d..4d73f0c35f 100644 --- a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp +++ b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" @@ -31,7 +31,7 @@ using DeviceOpInstance = ck::tensor_operation::device:: //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, CDEElementOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>; + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, CDEElementOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 16, 16, 8, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; - auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size; + ProblemSize ps = + problem_size; // make mutable copy because default stride values of 0 need to be updated + auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = ps; - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, int& stride, auto layout) { + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); @@ -123,7 +131,16 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + if(std::is_same_v, ck::half_t> && + std::is_same_v, ck::half_t>) + { + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-3, 1e-3); + } + else + { + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } } return true; diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp index 1b24bd3bba..3e69caf51e 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp +++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp @@ -18,6 +18,10 @@ #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -220,12 +224,12 @@ int main(int argc, char* argv[]) std::vector d0_gs_ms_ns_lengths{G0, G1, M, N}; std::vector d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1}; - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, Row{}); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, Col{}); + Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides, Row{}); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Row{}); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Row{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/48_pool3d_fwd/pool3d_fwd_common.hpp b/example/48_pool3d_fwd/pool3d_fwd_common.hpp index 788f38ec52..ef64dd167d 100644 --- a/example/48_pool3d_fwd/pool3d_fwd_common.hpp +++ b/example/48_pool3d_fwd/pool3d_fwd_common.hpp @@ -48,15 +48,16 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, layout); } throw std::runtime_error("Pool3d_fwd: problem with layout. "); - return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, layout); }; template ::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, layout); } throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); - return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, layout); }; template ; // DBetaDstVectorSize -int main() +int main(int argc, char* argv[]) { bool time_kernel = false; @@ -110,6 +110,25 @@ int main() ck::index_t G = 32; ck::index_t C = 64; + if(argc == 1) + { + // use default case + } + else if(argc == 6) + { + N = std::stoi(argv[1]); + H = std::stoi(argv[2]); + W = std::stoi(argv[3]); + G = std::stoi(argv[4]); + C = std::stoi(argv[5]); + } + else + { + std::cerr << "arg1 to 5: N, H, W, G, C" << std::endl; + + return 1; + } + Tensor dy({N, H, W, G, C}); Tensor x({N, H, W, G, C}); Tensor gamma({G, C}); diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt index a9e0d3f9ad..ffc6cec61d 100644 --- a/example/60_gemm_multi_ABD/CMakeLists.txt +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -1,3 +1,7 @@ +add_example_executable(example_gemm_multi_ABD_wmma_fp16 gemm_multi_ABD_wmma_fp16.cpp) +add_example_executable(example_gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8 gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp) +add_example_executable(example_gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp) +add_example_executable(example_gemm_multi_ABD_wmma_fastgelu_bf16_i8 gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp) add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) add_example_executable(example_gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp) add_example_executable(example_gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp) diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000..a30314f58c --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp @@ -0,0 +1,307 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + AsLayout, + BsLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<8, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<8, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 2; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, StrideB}, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000..086a0f4834 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp @@ -0,0 +1,299 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = FastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + AsLayout, + BsLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<8, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<8, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 2; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{0, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 0; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, StrideB}, + std::array{}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp new file mode 100644 index 0000000000..32345d1263 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp @@ -0,0 +1,362 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using ELayout = Row; + +struct AddScale +{ + static constexpr auto I0 = ck::Number<0>{}; + static constexpr auto I1 = ck::Number<1>{}; + static constexpr auto I2 = ck::Number<2>{}; + static constexpr auto I3 = ck::Number<3>{}; + + __host__ __device__ constexpr void + operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const + { + const auto a0_v_t = ck::vector_type{a0}; + const auto a1_v_t = ck::vector_type{a1}; + + auto r_v_t = ck::vector_type{}; + + r_v_t.AsType()(I0) = + scale * (a0_v_t.AsType()[I0] + a1_v_t.AsType()[I0]); + r_v_t.AsType()(I1) = + scale * (a0_v_t.AsType()[I1] + a1_v_t.AsType()[I1]); + r_v_t.AsType()(I2) = + scale * (a0_v_t.AsType()[I2] + a1_v_t.AsType()[I2]); + r_v_t.AsType()(I3) = + scale * (a0_v_t.AsType()[I3] + a1_v_t.AsType()[I3]); + + a = r_v_t.AsType()[I0]; + } + + __host__ __device__ constexpr void + operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const + { + a = scale * (a0 + a1); + } + + // this attribute controls the copy_function applying element_wise_op with + // pack4_data + constexpr const static bool is_pack4_invocable = true; + + float scale = 1.0; +}; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +using AElementOp = AddScale; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ELayout, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 256, + 128, + 32, + 8, + 8, + 16, + 16, + 4, + 4, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 64, 1, 4>, + S<8, 8, 8>>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 12: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{0.2}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA, StrideA}, + std::array{StrideB}, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000..00e2d7e33c --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = BF16; +using D1DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using D1Layout = D0Layout; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyAddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + AsLayout, + BsLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<8, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<8, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 1; + constexpr ck::index_t NumDTensor = 2; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer()}, + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{StrideD, StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp index 5220a4616e..405eac7df1 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp @@ -81,10 +81,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -120,23 +121,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) N, K, std::array{StrideA}, - std::array{StrideB, 0}, + std::array{StrideB, StrideB1}, std::array{StrideD}, StrideE, a_element_op, diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp index b424fdaf45..50e670bdf3 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp @@ -81,10 +81,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -120,23 +121,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) N, K, std::array{StrideA}, - std::array{StrideB, 0}, + std::array{StrideB, StrideB1}, std::array{}, StrideE, a_element_op, diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp index 03a74c04b7..50e1c21c8f 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp @@ -80,10 +80,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -119,23 +120,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) K, std::array{StrideA}, std::array{StrideB}, - std::array{0, StrideD}, + std::array{StrideB1, StrideD}, StrideE, a_element_op, b_element_op, @@ -261,7 +270,7 @@ int main(int argc, char* argv[]) { for(int n = 0; n < N; ++n) { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(0, n), d_m_n(m, n)); + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(m, n), d_m_n(m, n)); } } diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index 90e14de59c..a9a30b4c27 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -19,6 +19,9 @@ #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -160,12 +163,12 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides, Bypass{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; @@ -264,9 +267,9 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) { @@ -299,7 +302,6 @@ int main(int argc, char* argv[]) auto ref_op = ReferenceOpInstance{}; auto ref_invoker = ref_op.MakeInvoker(); - Tensor empty_tensor(std::vector{}, std::vector{}); auto ref_argument = ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op); diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp index ec1b2d6018..4f7414abfa 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -19,6 +19,9 @@ #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -140,12 +143,12 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); - Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); - Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides, Bypass{}); + Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides, Row{}); + Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; @@ -246,9 +249,9 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) { @@ -266,7 +269,7 @@ int main(int argc, char* argv[]) } } - Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); + Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides, Row{}); for(size_t n0 = 0; n0 < b_ns_ks.mDesc.GetLengths()[0]; ++n0) { diff --git a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp index 2afe01f02d..0a802ee27d 100644 --- a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp @@ -130,11 +130,12 @@ bool run_grouped_conv(bool do_verification, // Fill other lenghts than G,K with 1 and strides with 0 bias_g_k_lengths.fill(1); bias_g_k_strides.fill(0); - bias_g_k_lengths[0] = G; - bias_g_k_lengths[2] = K; - bias_g_k_strides[0] = K; // stride to G - bias_g_k_strides[2] = 1; // stride to K - const auto broadcasted_bias_desc = HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides); + bias_g_k_lengths[0] = G; + bias_g_k_lengths[2] = K; + bias_g_k_strides[0] = K; // stride to G + bias_g_k_strides[2] = 1; // stride to K + const auto broadcasted_bias_desc = + HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides, BiasLayout{}); // y = relu ( alpha1 * conv(x) + alpha2 * z + bias ) Tensor in(in_g_n_c_wis_desc); diff --git a/example/64_fpAintB_gemm/run_gemm_example.inc b/example/64_fpAintB_gemm/run_gemm_example.inc index dc2bdc18f0..41c8c42bac 100644 --- a/example/64_fpAintB_gemm/run_gemm_example.inc +++ b/example/64_fpAintB_gemm/run_gemm_example.inc @@ -28,7 +28,8 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor quant_b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // assume scale tensor is [1, n] - Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{})); + Tensor scale_k_n( + HostTensorDescriptor({K, N}, {0, 1_uz}, ck::tensor_layout::BypassLayoutVerification())); switch(config.init_method) { diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp index 53963fc514..8b8cee9e52 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp @@ -241,6 +241,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -285,8 +307,6 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - // do GEMM auto device_op = DeviceOpInstance{}; @@ -308,7 +328,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{I0, I0}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index 7a2d0153d9..8da49ef85d 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -162,6 +162,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -216,7 +238,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{StrideD, StrideD}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index fe1eca51b0..3ee4955ae4 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -251,6 +251,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -295,8 +317,6 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - // do GEMM auto device_op = DeviceOpInstance{}; @@ -318,7 +338,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{I0, I0}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 52ba3416a0..72ea7f1cb6 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -287,15 +287,18 @@ int main(int argc, char* argv[]) } } Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n( HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; @@ -422,7 +425,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( - {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}, Row{})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N * 2}, - {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; @@ -463,7 +467,7 @@ int main(int argc, char* argv[]) Tensor b_e_n_k({experts, K, N * 2}); e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); // handle scale before ref. for(int t = 0; t < tokens; ++t) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 92a0cd9e5c..5e306ac6dd 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -264,15 +264,18 @@ int main(int argc, char* argv[]) } Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n( HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; @@ -488,7 +491,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor d0_t_n( - HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0})); + HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0}, Bypass{})); Tensor d1_e_n( HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 354957c0d1..cc42c4b815 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -292,17 +292,19 @@ int main(int argc, char* argv[]) } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + Scale_Block_K - 1) / Scale_Block_K}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k(HostTensorDescriptor( {experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N}, - {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 6ca7d67f53..29e758f9d4 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -29,8 +29,9 @@ using F16 = ck::half_t; using F8 = ck::f8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F8; using B0DataType = I4; @@ -239,10 +240,10 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}, Bypass{})); Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/66_complex_contraction_bilinear/common_instances.hpp b/example/66_complex_contraction_bilinear/common_instances.hpp index 480ca5a0af..ed1c1dc303 100644 --- a/example/66_complex_contraction_bilinear/common_instances.hpp +++ b/example/66_complex_contraction_bilinear/common_instances.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -37,7 +37,7 @@ using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 2, ComputeDataType>; // clang-format on template a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides); + Tensor a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides, DefaultLayout{}); + Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides, DefaultLayout{}); + Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides, DefaultLayout{}); - Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); // For Imaginary Part of Complex Tensor - Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides); + Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides, DefaultLayout{}); + Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides, DefaultLayout{}); + Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides, DefaultLayout{}); - Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); // Intermediate E tensor Definition - Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl; std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl; @@ -349,8 +350,10 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) if(do_verification) { // Real Part Verification - Tensor c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); - Tensor c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_re( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor c_ms_ns_host_result_re1( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); - Tensor c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_img( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor c_ms_ns_host_result_img1( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); auto ref_argument_img = ref_op.MakeArgument( a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op); diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp index aaf0cb3891..69c0d6558f 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -269,10 +269,12 @@ int main(int argc, char* argv[]) Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -281,12 +283,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -480,7 +483,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -278,12 +280,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -477,7 +480,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -310,12 +313,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -506,7 +510,7 @@ int main(int argc, char* argv[]) { invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -286,7 +288,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 829bf9af24..5bb6454d2a 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -268,16 +268,18 @@ int main(int argc, char* argv[]) } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -286,7 +288,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp index efbd0f0c03..333f8a3d52 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -303,16 +303,18 @@ int main(int argc, char* argv[]) expert_ids.savetxt("expert_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -321,7 +323,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 3d79f2f6d3..b8ca26193d 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -47,7 +47,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api bwd --receipt 3 - --optdim 32,64,128,256 + --optdim 32,64,96,128,256 # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... ) @@ -169,6 +169,10 @@ if(CK_USE_OCP_FP8) list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() +# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) +list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) + target_compile_options(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS} INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS}) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index cb6cd44f64..2b872cb9b5 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -36,6 +36,13 @@ args: total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) -s_k seqlen_k (including new key/value), -1 means equal to s (default:-1) + also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode) + -s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1) + Provide positive strides per-batch to simulate physical padding on Q + -s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1) + for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride + along seqlen, instead of packed, same as xformer kv_padding, + must be greater than or equal to s_k -d head dim for q, k (default:128) -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) @@ -76,11 +83,20 @@ args: -repeat number of iterations to benchmark the kernel (default:20) -json 0: No Json, 1: Dump Results in Json format (default:0) -jsonfile json file name to dump results (default:fmha_fwd.json) + -q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override +-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override ``` Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case +## Padding Examples +Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively. + +Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively. + ## support features Currently we are still in rapid development stage, so more features/optimizations will be coming soon. @@ -128,7 +144,16 @@ Note FA use bottom-right by default to express swa case, here we require you exp ### dropout TBD +### sequence padding and variable length support +We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths. + +**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment. + +**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste. + +Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios. + ## FP8 experimental support As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. -Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later. +Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later. diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 42a9d5148a..81d34484a5 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -1,16 +1,19 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { + "fp32" : "FmhaFwdFp32", "fp16" : "FmhaFwdFp16", "bf16" : "FmhaFwdBf16", "fp8" : "FmhaFwdFp8", "fp8fp16": "FmhaFwdFp8Fp16", - "fp8bf16": "FmhaFwdFp8Bf16" + "fp8bf16": "FmhaFwdFp8Bf16", + "fp8fp32": "FmhaFwdFp8Fp32" } BWD_DTYPE_MAP = { + "fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16" } diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 0d8f366d8a..e2f69fa49a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -601,6 +601,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl cond &= pipeline.F_squant == 'f' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 8f710050b1..7319ef7ea1 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -50,16 +50,10 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; -using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits; using fmha_mask_{F_idx} = {F_mask}; using fmha_dropout_{F_idx} = {F_dropout}; @@ -94,19 +88,19 @@ using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, false, - {F_dpad}>>; + ({F_dpad} > 0)>>; using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, false, - {F_dvpad}>>; + ({F_dvpad} > 0)>>; using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType, false, - {F_dpad}>>; + ({F_dpad} > 0)>>; using fmha_bwd_dq_dk_dv_kernel_{F_idx} = ck_tile::FmhaBwdDQDKDVKernel float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ - const bool has_load_tr = ck_tile::is_load_tr_supported(); + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); float r = -1; {F_dispatch} return r; @@ -220,9 +214,9 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>; r = fmha_bwd_>(s, a); return r; }} @@ -278,8 +272,8 @@ class FmhaBwdDQDKDVKernel: F_hdim : int # hdim F_dtype : str # data type F_tile : FmhaBwdDQDKDVTileSize - F_dpad : str # - F_dvpad : str # + F_dpad : Literal[0, 8 ,1] + F_dvpad : Literal[0, 8 ,1] F_bias : str # F_dbias : str # F_dropout : str # @@ -320,8 +314,8 @@ class FmhaBwdDQDKDVKernel: F_wm1 = self.F_tile.F_wm1, F_wn1 = self.F_tile.F_wn1, F_wk1 = self.F_tile.F_wk1, - F_dpad = BOOL_MAP[self.F_dpad], - F_dvpad = BOOL_MAP[self.F_dvpad], + F_dpad = self.F_dpad, + F_dvpad = self.F_dvpad, F_bias = BIAS_MAP[self.F_bias], F_dbias = BOOL_MAP[self.F_dbias], F_dropout = DROPOUT_MAP[self.F_dropout], @@ -337,8 +331,8 @@ class FmhaBwdDQDKDVKernel: def name(self) -> str: def pad_name() -> str: n = '' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' + if self.F_dpad : n += f'd{self.F_dpad}' + if self.F_dvpad : n += f'dv{self.F_dvpad}' if n != '' : n = 'p' + n return n pn = pad_name() @@ -376,10 +370,18 @@ class FmhaBwdDQDKDVKernel: # TODO: design a more practical way to do it # this is current supported tile size. def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: - if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': + if dtype == 'fp32' and tr_load == 'f': + return [ + # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + ] + elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': return [ FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), @@ -621,8 +623,8 @@ class FmhaBwdApiTrait: dbias : str dropout : str spad1d : str # spad for 1d kernels (dot/convert) - dpad : str - dvpad : str + dpad : Literal[0, 1, 8] + dvpad : Literal[0, 1, 8] deterministic : str mask_impl : str tr_load : str @@ -651,13 +653,13 @@ class FmhaBwdApiTrait: @property def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' + if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0' + else: return f'a.hdim_q % {self.dpad} == 0' @property def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' + if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0' + else: return f'a.hdim_v % {self.dvpad} == 0' @property def extra_cond(self) -> str: @@ -677,8 +679,9 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 + F_dvpad = 't' if self.dvpad else 'f' return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d, - F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) @property def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: @@ -693,8 +696,9 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 + F_dpad = 't' if self.dpad else 'f' return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, - F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad, + F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) @@ -720,7 +724,7 @@ class FmhaBwdApiPool: F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], - F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad, F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond, F_convert_dq_bn0=trait.convert_dq_bn0) @@ -793,7 +797,10 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) - for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): + dpad_options = itertools.product(*([[0, 8, 1]] * 2)) + tf = ["t", "f"] + for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product( + tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf): assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" hdim = tile.F_bhdq if (mode == "group") and (spad1d == "f"): @@ -804,8 +811,12 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue if ("wg32" in dropout): continue - if tr_load == "t" and (dpad == "t" or dvpad == "t"): + if tr_load == "t": continue # tr_load cannot work with dpad or dvpad + else: # tr_load == "f" + # do not generate instance with only 1 of dpad/dvpad being 8 + if dpad != dvpad and dpad == 8: + continue if optdim_list != [-1]: if hdim not in optdim_list: continue @@ -861,6 +872,30 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm cond = dtype in ['fp16', 'bf16'] if not cond: continue + + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= dpad == dvpad + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [64, 128] + cond &= dpad == dvpad + cond &= mode == 'batch' + cond &= bias == 'no' + cond &= dropout == 'no' + cond &= mask == 's_no' + cond &= deterministic == "f" + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue + gen_dot_do_o[t.dot_do_o_kernel] = True gen_dq_dk_dv[t.dq_dk_dv_kernel] = True if not t.convert_dq_kernel.disabled: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index d9452206e7..f898d5f7b2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -25,6 +25,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = { 32 : 32, + 48 : 48, 64 : 64, 96 : 128, 128: 128, @@ -163,8 +164,8 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - - const bool has_load_tr = ck_tile::is_load_tr_supported(); + + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} return r; @@ -248,22 +249,21 @@ class FmhaFwdApiTrait: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False - - @property - def seqtune(self) -> str: - if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true - else: + + def seqtune(self, max_bm0 : int) -> str: + if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/' + else: return f'a.seqlen_q <= {self.bm0}' @property def skcheck(self) -> str: if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)' + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' elif self.pipeline_tag == 'qr_async_trload': if self.skpad == 't' : return 'true' else: return 'true' @@ -351,7 +351,7 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' - + if self.F_trload == 't' : n += '_trload' else: n += '_ntrload' @@ -378,7 +378,7 @@ class FmhaFwdApiPool: "t": "has_load_tr", "f": "true" } - + per_tr_load =str() for tr_load in ["t", "f"]: per_dtypes=str() @@ -386,6 +386,7 @@ class FmhaFwdApiPool: per_hdim_case=str() for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] + max_bm0 = max((t.bm0 for t in traits), default=0) inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' @@ -393,7 +394,7 @@ class FmhaFwdApiPool: F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, @@ -534,7 +535,20 @@ class KernelComponentFactory: # this is current supported tile size per hdim @staticmethod def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': + if dtype == 'fp32': + return { + # bm0, bn0, bk0, bn1, bk1, + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } + elif dtype == 'fp16' or dtype == 'bf16': return { (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), @@ -550,12 +564,16 @@ class KernelComponentFactory: (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } - elif dtype == 'fp8' or dtype == 'bf8': + elif dtype == 'fp8' or dtype == 'fp8bf16': return { (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } + elif dtype == 'fp8fp32': + return { + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } else: return None @@ -567,9 +585,15 @@ class KernelComponentFactory: # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' pipelines = [] - if dtype in ['fp16', 'bf16']: + if dtype in ['fp32']: + squant = 'f' + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + elif dtype in ['fp16', 'bf16']: + squant = 'f' for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) @@ -589,11 +613,12 @@ class KernelComponentFactory: pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: + elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']: # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: + for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + elif dtype in ['fp8fp16', 'bf8']: # TODO None else: @@ -621,6 +646,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, next_tile in zip(tiles, tiles[1:]): + assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0' for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': @@ -630,12 +657,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue - if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): - continue + if dtype != 'fp32': + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): continue @@ -674,27 +702,61 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue # Aiter(mha_fwd) integration elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= mode == 'batch' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue # Aiter(mha_varlen_fwd) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= mode == 'group' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue # aiter::mha_fwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue + elif receipt == 888: + cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32'] + cond &= pipeline.F_vlayout == 'row' + cond &= hdim == 128 + if not cond: + continue + + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [48, 128] + cond &= mode == 'batch' + cond &= pipeline.F_bias == 'no' + cond &= pipeline.F_lse == 'f' + cond &= pipeline.F_dropout == 'f' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + cond &= pipeline.F_mask == 's_no' + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 0ebeaddf9c..38491b56c4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -184,6 +184,9 @@ class FmhaFwdAppendKVApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) @dataclass @@ -341,6 +344,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op cond &= pipeline.F_vlayout == 'row' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 3b48b3d005..281357ef1e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -645,7 +645,6 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -769,6 +768,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt cond &= pipeline.F_squant == 'f' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) @@ -835,6 +841,13 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim cond = dtype in ['fp16', 'bf16'] if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + gen.append(k) return gen diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 7b93e9654c..3624b7b387 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -465,14 +465,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]): pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) elif dtype in ['fp8', 'bf8']: - # TODO - None + # no need lse/dropout kernels + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -560,6 +560,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if not cond: continue + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/example_fmha_bwd.cpp b/example/ck_tile/01_fmha/example_fmha_bwd.cpp index e0e1fba668..73b3c1e619 100644 --- a/example/ck_tile/01_fmha/example_fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_bwd.cpp @@ -43,7 +43,7 @@ auto create_args(int argc, char* argv[]) "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") .insert("dbias", "0", "output bias gradient or not") - .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("prec", "fp16", "data type. fp32/fp16/bf16") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -159,7 +159,11 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + if(data_type == "fp32") + { + return run(arg_parser) == bwd_result::success ? 0 : -2; + } + else if(data_type == "fp16") { return run(arg_parser) == bwd_result::success ? 0 : -2; } diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index c3bbb7a558..c27a5ce1ae 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -33,6 +33,10 @@ auto create_args(int argc, char* argv[]) "0", "seqlen_k for new key/value, 0 means not to use this at all; " "-1 to choose s_knew in [1, s] randomly.") + .insert("s_qpad", + "-1", + "seqlen_q stride between 2 batches (group-mode optional).\n" + "Provide positive strides per-batch to simulate physical padding on Q.") .insert("s_kpad", "-1", "seqlen_k stride between 2 batches, currently used in group-mode only\n" @@ -44,21 +48,15 @@ auto create_args(int argc, char* argv[]) .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" - "note when squant=1, this value will be modified by range_q/k") + "note when squant=1, this value will be modified") .insert("logits_soft_cap", "0", "attention logits soft capping value.") - .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") - .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") - .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") - .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") - .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") .insert("squant", "auto", "if using static quantization fusion or not. auto: fp8 will default use squant, " "other will not\n" "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " "P and O.\n" - "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " - "range_p, range_o") + "calculate scale_s, scale_p, scale_o auto") .insert("iperm", "1", "permute input\n" @@ -69,7 +67,7 @@ auto create_args(int argc, char* argv[]) "n or 0, no bias\n" "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -89,7 +87,7 @@ auto create_args(int argc, char* argv[]) "uf", "init method:\n ui or 0 - uniform random int\n ni - normalized random int" "\n uf or 1 - uniform random float\n nf - normalized random float" - "\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization") + "\n tf or 2 - trig float\n") .insert("seed", "11939", "random seed used for initializing input tensors. 0 for " @@ -113,7 +111,15 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "fmha_fwd.json", "json file name to dump results"); + .insert("jsonfile", "fmha_fwd.json", "json file name to dump results") + .insert("q_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override.") + .insert("kv_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -133,6 +139,9 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); + auto seqlen_qpads = arg_parser.get_int_vec("s_qpad"); + auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens"); + auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens"); ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); bool i_perm = arg_parser.get_bool("iperm"); bool o_perm = arg_parser.get_bool("operm"); @@ -148,11 +157,6 @@ auto run(const ck_tile::ArgParser& arg_parser) uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); bool drop_prefs = arg_parser.get_bool("drop_prefs"); std::string mask_str = arg_parser.get_str("mask"); - float range_q = arg_parser.get_float("range_q"); - float range_k = arg_parser.get_float("range_k"); - float range_v = arg_parser.get_float("range_v"); - float range_p = arg_parser.get_float("range_p"); - float range_o = arg_parser.get_float("range_o"); bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); std::string init_method = arg_parser.get_str("init"); @@ -185,7 +189,10 @@ auto run(const ck_tile::ArgParser& arg_parser) hdim_q, hdim_v, seqlen_knew, + seqlen_qpads, seqlen_kpads, + q_eff_lens_per_batch, + kv_eff_lens_per_batch, rotary_dim, i_perm, o_perm, @@ -201,11 +208,6 @@ auto run(const ck_tile::ArgParser& arg_parser) drop_offset, drop_prefs, mask_str, - range_q, - range_k, - range_v, - range_p, - range_o, squant, is_rotary_interleaved, num_splits, @@ -225,7 +227,11 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + if(data_type == "fp32") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp16") { return run(arg_parser) == fwd_result::success ? 0 : -2; } @@ -237,6 +243,14 @@ int main(int argc, char* argv[]) { return run(arg_parser) == fwd_result::success ? 0 : -2; } + else if(data_type == "fp8bf16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp8fp32") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } std::cerr << "Unsupported precision: " << data_type << std::endl; return -1; } diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 569c98a458..7ddb65a2db 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -52,7 +52,16 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair get_query_shape() const @@ -172,6 +183,8 @@ struct Problem mask_info mask; TensorLayout input_layout; TensorLayout output_layout; + std::vector q_eff_lens; + std::vector kv_eff_lens; }; struct RunConfig @@ -326,8 +339,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) q_buf.ToDevice(q.data()); k_buf.ToDevice(k.data()); v_buf.ToDevice(v.data()); + // Ensure output buffer is zero-initialized so padded regions compare cleanly + o_buf.SetZero(); - ck_tile::fmha_fwd_v3_args args; + ck_tile::fmha_fwd_v3_args args{}; args.data_type = problem.data_type; args.batch = problem.batch; @@ -380,6 +395,60 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) : problem.seqlen_q * problem.hdim; args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; + // Optional cumulative seqlen overrides (exclude PAD) + const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; + const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; + + auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { + std::vector eff; + if(!opt_vec.empty() && opt_vec[0] != -1) + { + eff.assign(opt_vec.begin(), opt_vec.end()); + if(eff.size() < static_cast(problem.batch)) + { + eff.resize(problem.batch, eff.back()); + } + } + else + { + eff.assign(problem.batch, fallback); + } + return eff; + }; + + const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); + const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); + + // Calculate cumulative sums for kernel arguments if varlen is used + std::vector cuq_cum, cukv_cum; + auto calculate_cumulative = [&](const std::vector& per_batch_vec, + std::vector& cum_vec) { + cum_vec.resize(per_batch_vec.size() + 1); + cum_vec[0] = 0; + for(std::size_t i = 0; i < per_batch_vec.size(); ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + }; + + if(has_varlen_q) + { + calculate_cumulative(eff_q_vec, cuq_cum); + } + if(has_varlen_k) + { + calculate_cumulative(eff_kv_vec, cukv_cum); + } + + ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); + ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); + cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); + cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); + args.cu_seqlen_q_ptr = + !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) + : nullptr; + args.cu_seqlen_kv_ptr = + !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) + : nullptr; + ck_tile::stream_config stream_config{nullptr, true, /*log_level=*/0, @@ -442,15 +511,72 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) o_ref = o_ref.transpose({0, 2, 1, 3}); } - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); + // If variable lengths are provided, compute per-batch references + // with the effective lengths; else compute a single full reference. + if(has_varlen_q || has_varlen_k) + { + // Variable-length aware verification: zero-fill padded region and only compute valid part. + o_ref.SetZero(); + + for(int b = 0; b < problem.batch; ++b) + { + const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; + + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; + + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(b, idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); + v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, + problem.mask, + o_b, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + o_ref(b, s, h, d) = o_b(0, s, h, d); + } + } + } + } + } + else + { + // No varlen override: compute the full reference once + host::fmha_fwd(q, + k, + v, + problem.mask, + o_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + } ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index f1f8eee5e4..6cd1cd94fa 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -15,6 +15,10 @@ #include #include +struct FmhaBwdFp32 +{ +}; + struct FmhaBwdFp16 { }; @@ -26,6 +30,26 @@ struct FmhaBwdBf16 template struct FmhaBwdTypeConfig; +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = float; + using KDataType = float; + using VDataType = float; + using GemmDataType = float; + using BiasDataType = float; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = float; + using OGradDataType = float; + using QGradDataType = float; + using KGradDataType = float; + using VGradDataType = float; + using BiasGradDataType = float; +}; + template <> struct FmhaBwdTypeConfig { @@ -368,8 +392,8 @@ template +auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) +{ + double rtol = 1e-4; + double atol = 1e-4; + return ck_tile::make_tuple(rtol, atol); +} + template <> auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) { @@ -77,7 +85,9 @@ bwd_result fmha_bwd_run(mode_enum mode, std::optional json = std::nullopt) { const std::string data_type = []() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "bf16"; @@ -763,15 +773,21 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_acc_host); dq_buf.ToDevice(dq_host.data()); dk_buf.ToDevice(dk_host.data()); dv_buf.ToDevice(dv_host.data()); + dq_acc_buf.ToDevice(dq_acc_host.data()); o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); - dq_buf.SetZero(); dbias_buf.SetZero(); - dq_acc_buf.SetZero(); + + // non-deterministic kernels use atomic add to write dq + // Some block may be skipped with causal mask and dq are not set to zeros + // In these cases thus we need to zero out it first + if(!deterministic || mask.type != mask_enum::no_mask) + dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; fmha_bwd(fmha_traits, fmha_args, stream_config_v); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index df1e9e5699..761def6d6a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,6 +17,10 @@ #include #include +struct FmhaFwdFp32 +{ +}; + struct FmhaFwdFp16 { }; @@ -41,9 +45,29 @@ struct FmhaFwdFp8Bf16 { }; +struct FmhaFwdFp8Fp32 +{ +}; + template struct FmhaFwdTypeConfig; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = float; + using KDataType = float; + using VDataType = float; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = float; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; +}; + template <> struct FmhaFwdTypeConfig { @@ -108,6 +132,38 @@ struct FmhaFwdTypeConfig using ODataType = ck_tile::bf8_t; }; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; +}; + struct FmhaMasks { using NoMask = ck_tile::GenericAttentionMask; @@ -126,11 +182,20 @@ struct fmha_fwd_args void* lse_ptr; void* o_ptr; + // Optional cumulative sequence length arrays + // Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD) + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] + const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + // Group mode: seqstart_padded_* provide physical starts including PAD (optional) + const void* seqstart_padded_q_ptr = nullptr; // [batch+1] + const void* seqstart_padded_k_ptr = nullptr; // [batch+1] + ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -518,7 +583,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.min_seqlen_q, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.seqstart_padded_q_ptr, + args.seqstart_padded_k_ptr); } else { // create batch mode kernel arguments @@ -564,7 +631,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.cu_seqlen_q_ptr, + args.cu_seqlen_kv_ptr); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 397245ab32..0703af71e3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -41,6 +41,14 @@ auto get_elimit(std::string /*init_method*/) return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + template <> auto get_elimit(std::string /*init_method*/) { @@ -50,20 +58,30 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string init_method) +auto get_elimit(std::string /*init_method*/) { - if(init_method == "ui" || init_method == "ni") - { - unsigned max_rounding_point_distance = 0; - double atol = 2e-3; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } - else - { - unsigned max_rounding_point_distance = 1; - double atol = 0.0625; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } + using TypeConfig = FmhaFwdTypeConfig; + using ODataType = typename TypeConfig::ODataType; + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + double rtol = 0; + double atol = 16 * (o_dtype_max > 240 ? 2 : 1); + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); } int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) @@ -141,7 +159,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t seqlen_knew, + std::vector seqlen_qpads, std::vector seqlen_kpads, + std::vector q_eff_lens_per_batch, + std::vector kv_eff_lens_per_batch, ck_tile::index_t rotary_dim, bool i_perm, bool o_perm, @@ -157,11 +178,6 @@ fwd_result fmha_fwd_run(mode_enum mode, uint64_t drop_offset, bool drop_prefs, std::string mask_str, - float range_q, - float range_k, - float range_v, - float range_p, - float range_o, bool squant, bool is_rotary_interleaved, ck_tile::index_t num_splits, @@ -172,7 +188,9 @@ fwd_result fmha_fwd_run(mode_enum mode, std::optional json = std::nullopt) { const std::string data_type = []() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "bf16"; @@ -180,6 +198,10 @@ fwd_result fmha_fwd_run(mode_enum mode, return "fp8"; else if constexpr(std::is_same_v) return "bf8"; + else if constexpr(std::is_same_v) + return "fp8bf16"; + else if constexpr(std::is_same_v) + return "fp8fp32"; else static_assert(false); }(); @@ -290,6 +312,24 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); + // Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv) + const bool has_group_padding = + (mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) || + (mode == mode_enum::group && (seqlen_kpads[0] >= 0)); + const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() || + !kv_eff_lens_per_batch.empty())); + const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim); + const bool using_pagedkv = (0 < page_block_size); + const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx; + if((using_appendkv || using_pagedkv || using_splitkv) && + (has_group_padding || has_batch_efflens)) + { + std::cerr << "Padding (physical or effective lengths) is not supported with " + "appendkv/splitkv/pagedkv pipelines" + << std::endl; + return fwd_result::invalid_args; + } + std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = generate_missing_seqlens(mode, batch, @@ -353,6 +393,44 @@ fwd_result fmha_fwd_run(mode_enum mode, const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); + // Optional padded Q seqstarts (group-mode only) + std::vector seqstart_q_with_padding_host; + if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) + { + if(seqlen_qpads.size() < static_cast(batch)) + { + seqlen_qpads.resize(batch, seqlen_qpads.back()); + } + if(seqlen_qpads.size() == static_cast(batch)) + { + seqstart_q_with_padding_host = to_seqstarts( + ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); + } + } + + // Optional batch-mode cumulative seqlen overrides + std::vector cuq_cum, cukv_cum; + if(mode == mode_enum::batch) + { + auto calculate_cumulative = [&](std::vector& per_batch_vec, + std::vector& cum_vec) { + if(!per_batch_vec.empty() && per_batch_vec[0] != -1) + { + if(per_batch_vec.size() < static_cast(batch)) + { + per_batch_vec.resize(batch, per_batch_vec.back()); + } + cum_vec.resize(batch + 1); + cum_vec[0] = 0; + for(int i = 0; i < batch; ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + } + }; + + calculate_cumulative(q_eff_lens_per_batch, cuq_cum); + calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); + } + using TypeConfig = FmhaFwdTypeConfig; using QDataType = typename TypeConfig::QDataType; @@ -367,22 +445,6 @@ fwd_result fmha_fwd_run(mode_enum mode, using OaccDataType = typename TypeConfig::OaccDataType; using ODataType = typename TypeConfig::ODataType; - float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float p_dtype_max = v_dtype_max; // assume p and v is the same type - float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - - float scale_p = 1.f; - float scale_o = 1.f; - - if(squant) - { - scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max); - scale_p = p_dtype_max / range_p; - scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max); - } - // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = @@ -452,8 +514,15 @@ fwd_result fmha_fwd_run(mode_enum mode, // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - const ck_tile::index_t shape_seqlen_q = + // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen + const ck_tile::index_t shape_seqlen_q_lse = (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); + // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch + ? seqlen_qs[0] + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() + : seqstart_q_with_padding_host.back())); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() @@ -511,7 +580,7 @@ fwd_result fmha_fwd_run(mode_enum mode, // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q} + lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -528,7 +597,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx ? std::array{batch} : std::array{1}); - + float max_o = 5.0; if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); @@ -576,32 +645,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillTrigValue{}(vnew_host); ck_tile::FillTrigValue{}(bias_host); } - else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3") - { - // suitable for fp8 quantization - if(!squant) - { - std::cerr << "init method " << init_method << " can not be used without quantization" - << std::endl; - return fwd_result::invalid_args; - } - ck_tile::FillUniformDistribution{0.f, q_dtype_max, next_seed()}(q_host); - ck_tile::FillUniformDistribution{0.f, k_dtype_max, next_seed()}(k_host); - ck_tile::FillUniformDistribution{0.f, k_dtype_max, next_seed()}(knew_host); - ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(v_host); - ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(vnew_host); - - // bias_fp8 = qscale_bias * bias_fp32 - float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k); - // Assume bias is in [0.f, 1.f] in original fp32 - ck_tile::FillUniformDistribution{0.f, qscale_bias, next_seed()}(bias_host); - } - else - { - std::cerr << "Unknown value for init argument: " << init_method << std::endl; - return fwd_result::invalid_args; - } - if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); @@ -625,8 +668,8 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); @@ -635,6 +678,16 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() + ? 0 + : seqstart_q_with_padding_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k_padded_buf( + seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 + : cuq_cum.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem cu_seqlen_kv_buf( + cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) @@ -650,15 +703,90 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); + float scale_p = 1.f; + float scale_o = 1.f; + if(squant) + { + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float p_dtype_max = v_dtype_max; // assume p and v is the same type + // Q tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = q_dtype_max / max_value; + + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // K tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + float scale = k_dtype_max / max_value; + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // V tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = k_dtype_max / max_value; + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + + scale_o = (1.0 / p_dtype_max) / scale; + } + + scale_p = p_dtype_max; + + if constexpr(std::is_same_v) + { + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + scale_o = scale_o * o_dtype_max / max_o; + } + } + q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); - knew_buf.ToDevice(knew_host.data()); v_buf.ToDevice(v_host.data()); + knew_buf.ToDevice(knew_host.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() - : seqstart_k_with_padding_host.data()); + // Keep logical starts in seqstart_k; pass padded K via separate pointer + seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_q_padded_buf.ToDevice( + seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); + seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr + : seqstart_k_with_padding_host.data()); + cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); + cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); @@ -711,6 +839,54 @@ fwd_result fmha_fwd_run(mode_enum mode, std::cout << ", cache_batch_idx:" << use_cache_batch_idx; } #endif + // Padding / effective length diagnostic logging + auto print_vec = [&](const char* label, const std::vector& v) { + if(v.empty()) + return; + std::cout << ", " << label << ":["; + for(std::size_t i = 0; i < v.size(); ++i) + { + if(i) + std::cout << ","; + std::cout << v[i]; + } + std::cout << "]"; + }; + + if(has_group_padding) + { + bool has_qpad = !seqstart_q_with_padding_host.empty(); + bool has_kpad = (seqlen_kpads[0] >= 0); + if(has_qpad) + { + print_vec("q_logical", seqlen_qs); + print_vec("q_padded", seqlen_qpads); + } + if(has_kpad) + { + print_vec("k_logical", seqlen_ks); + print_vec("k_padded", seqlen_kpads); + } + } + else if(has_batch_efflens) + { + // derive effective lengths from cumulative arrays if present + if(!cuq_cum.empty()) + { + std::vector eff_q(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_q[b_i] = static_cast(cuq_cum[b_i + 1] - cuq_cum[b_i]); + print_vec("q_eff", eff_q); + } + if(!cukv_cum.empty()) + { + std::vector eff_kv(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_kv[b_i] = static_cast(cukv_cum[b_i + 1] - cukv_cum[b_i]); + print_vec("kv_eff", eff_kv); + } + } + std::cout << std::flush; const auto init_traits = [&](auto& traits) { @@ -794,8 +970,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments @@ -810,8 +986,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); @@ -925,6 +1101,29 @@ fwd_result fmha_fwd_run(mode_enum mode, { args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } + + // Group-mode: optional physical padded starts for Q/K + if(mode == mode_enum::group) + { + args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() + ? nullptr + : seqstart_q_padded_buf.GetDeviceBuffer()); + args.seqstart_padded_k_ptr = + (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); + } + + // Batch-mode: optional cumulative effective seqlen overrides + if(mode == mode_enum::batch) + { + args.cu_seqlen_q_ptr = cuq_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_q_buf.GetDeviceBuffer()); + args.cu_seqlen_kv_ptr = cukv_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_kv_buf.GetDeviceBuffer()); + } } else if constexpr(std::is_same_v>) { @@ -964,7 +1163,7 @@ fwd_result fmha_fwd_run(mode_enum mode, } }; - const float appendkv_ave_time = [&] { + auto run_appendkv = [&](const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_APPENDKV_API if(need_append_kvcache) { @@ -974,18 +1173,19 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_appendkv_args fwd_appendkv_args; init_args(fwd_appendkv_args); - return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config); + return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc); } #endif return 0.0f; - }(); + }; + const float appendkv_ave_time = run_appendkv(stream_config); if(appendkv_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; return fwd_result::no_instance; } - const float fwd_ave_time = [&] { + auto run_fwd = [&](const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_PAGEDKV_API if(1 == num_splits && use_kvcache) { @@ -995,8 +1195,7 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_pagedkv_args fmha_pagedkv_args; init_args(fmha_pagedkv_args); - const float ave_time = - fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); + const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc); #if CK_TILE_FMHA_FWD_SPLITKV_API // If there is no instance for these args, fallback to fmha_fwd_splitkv if(ave_time >= 0.0f) @@ -1015,7 +1214,7 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_splitkv_args fmha_splitkv_args; init_args(fmha_splitkv_args); - return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc); } #endif // CK_TILE_FMHA_FWD_SPLITKV_API fmha_fwd_traits fmha_traits; @@ -1024,8 +1223,9 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_args fmha_args; init_args(fmha_args); - return fmha_fwd(fmha_traits, fmha_args, stream_config); - }(); + return fmha_fwd(fmha_traits, fmha_args, sc); + }; + const float fwd_ave_time = run_fwd(stream_config); if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; @@ -1099,11 +1299,24 @@ fwd_result fmha_fwd_run(mode_enum mode, } else { +#if CK_TILE_FMHA_FWD_APPENDKV_API + // When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times + // when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. + if(0 < rotary_dim && stream_config.time_kernel_) + { + const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0}; + q_buf.ToDevice(q_host.data()); + run_appendkv(stream_config2); + run_fwd(stream_config2); + } +#endif o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); - constexpr bool supports_squant = std::is_same_v; + constexpr bool supports_squant = std::is_same_v || + std::is_same_v || + std::is_same_v; auto p_compute_element_func = [&]() { if constexpr(supports_squant) @@ -1113,9 +1326,11 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); auto oacc_element_func = [&]() { - if constexpr(supports_squant) + if constexpr(std::is_same_v && supports_squant) return ck_tile::composes(ck_tile::saturates{}, ck_tile::scales{scale_o}); + else if constexpr(supports_squant) + return ck_tile::scales{scale_o}; else return ck_tile::identity{}; }(); @@ -1127,15 +1342,29 @@ fwd_result fmha_fwd_run(mode_enum mode, for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + if(mode == mode_enum::batch) + { + if(!cuq_cum.empty()) + { + real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb]; + } + if(!cukv_cum.empty()) + { + real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb]; + } + } // adjust matrix index according to the mode const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t cache_b_idx = (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); const ck_tile::index_t query_offset = - (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + (mode == mode_enum::batch + ? 0 + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] + : seqstart_q_with_padding_host[wb])); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 @@ -1498,8 +1727,10 @@ fwd_result fmha_fwd_run(mode_enum mode, if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + const ck_tile::index_t query_offset_lse = + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); }); cur_pass = ck_tile::check_err(lse_host_result, diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 10cb5149a4..4bd1d1a367 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -56,6 +56,11 @@ struct fmha_fwd_v3_args index_t stride_o; index_t nhead_stride_o; index_t batch_stride_o; + + // Optional batch-mode cumulative seqlen overrides (exclude PAD) + // If provided, they override per-batch effective lengths to skip tail padding. + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index e0fbad39a5..194675f962 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -158,7 +158,9 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.window_size_left, args.window_size_right, args.mask_type, - remap_opt); + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_kv_ptr); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 88c16cceb6..31ad800039 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,3 +18,36 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done + +#Padding Benchmarks: batch mode (baseline vs low/med/high pad) +prec="fp16" +base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no pad) +$EXE $base_batch_args + +# low pad (≈90–95% effective) +$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + +# Padding Benchmarks: group mode (baseline vs low/med/high physical pad) +seqlens_q="1024,768,512,256" +seqlens_k="1024,768,512,256" +base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no physical pad) +$EXE $base_group_args + +# low physical pad +$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + +# medium physical pad +$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + +# high physical pad +$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh index b847e85398..a3f7d68eb3 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -23,3 +23,20 @@ done done done done + +# Padding benchmark comparisons for v3 (batch mode only) +# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ==== +prec="fp16" +base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID" + +# baseline (no pad) +$EXE $base_v3_args + +# low pad (≈90–95% effective) +$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt index ea601ec002..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt @@ -1,2 +0,0 @@ -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt index ea601ec002..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt @@ -1,2 +0,0 @@ -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt index 1497d491bb..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt @@ -1,31 +0,0 @@ -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt index 90c5e2b7fb..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt +++ b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt @@ -1,4 +0,0 @@ -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index e7babd2744..5c2a5a4b3d 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -34,15 +34,15 @@ function print_log_header(){ } #run verification tests -example/ck_tile/01_fmha/script/smoke_test_fwd.sh -example/ck_tile/01_fmha/script/smoke_test_bwd.sh +time example/ck_tile/01_fmha/script/smoke_test_fwd.sh +time example/ck_tile/01_fmha/script/smoke_test_bwd.sh #run performance benchmarks export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" print_log_header $fmha_fwd_log $env_type $branch $host_name -example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log +time example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log" print_log_header $fmha_bwd_log $env_type $branch $host_name -example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log +time example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index 3b59505ff0..cd51dde2d4 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -6,7 +6,7 @@ SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) EXE_NAME=tile_example_fmha_bwd EXE="$(find . -name $EXE_NAME -type f | head -n 1)" KNAME=1 -GPU_arch=$GPU_arch +GPU_arch=${GPU_arch:-""} if [ -z "$GPU_arch" ] ; then GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') fi @@ -31,7 +31,17 @@ run_exe() { set -ex } +test_h_s_mask() { + run_exe -b=1 -h=4 -h_k=2 -s=259 $@ + run_exe -b=2 -h=2 -s=516 -s_k=253 $@ + run_exe -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 $@ + run_exe -b=1 -h=2 -s=900 -s_k=258 -mask=2 $@ + run_exe -b=2 -h=1 -s=987 -s_k=219 -mask=t:128,30 $@ + run_exe -b=2 -h=3 -h_k=1 -s=244 -s_k=499 -mask=b:4,35 $@ +} + set -x +# main tests for prec in "fp16" "bf16" ; do for perm in 0 1 ; do for hdim in 32 64 128 256 ; do @@ -40,21 +50,21 @@ for bias in "n" "a" ; do for dbias in 0 ; do for p_drop in 0.0 0.2 ; do for deterministic in 0 ; do +test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +done +done +done +done +done +done +done +done -run_exe -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS - -done -done -done -done -done -done -done +# additional cases +for hdim in 40 48 72 96 ; do +test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS done set +x diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index c087a1fb3e..fca6b8d0cd 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -94,7 +94,30 @@ run_fp8_tests() { for b in 1 2 ; do for hdim in 64 128 256 ; do - run_exe -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8bf16_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8fp32_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + done ; done ; done ; done } @@ -114,10 +137,121 @@ run_fp16_appendkv_tests() { done ; done ; done } +run_padding_smoke_tests() { + # Padding-only smoke tests for batch/group mode using COMMON_ARGS + local prec="fp16" + + # Batch mode: padding via effective lengths (exclude PAD) + # Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches + local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low pad (≈90–95% effective) + $EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + # medium pad (≈60–75% effective) + $EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + # high pad (≈30–40% effective) + $EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + + # Group mode: padding via physical stride along seqlen + local seqlens_q="1024,768,512,256" + local seqlens_k="1024,768,512,256" + local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low physical pad + $EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + # medium physical pad + $EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + # high physical pad + $EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 +} + +run_padding_basic_boundary_tests() { + # Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh) + local prec + local perm + + # Group mode: Q&K padded with per-batch different strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \ + -s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # slightly larger, uneven padding strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \ + -s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # only K padded; Q unpadded + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \ + -s=55 -s_k=256 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # use cu_seqlen overrides to skip tail PAD + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \ + -q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \ + -q_eff_lens=55,60 -kv_eff_lens=200,256 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # no padding (equal), mixed Q/KV, all len=1 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # highly variable logical lengths + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \ + -s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \ + -s=256,129 -s_k=256,129 -s_kpad=256 \ + -bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \ + -kname=$KNAME $COMMON_ARGS + done +} + set -x run_fp16_bf16_tests -# run_fp8_tests +run_padding_smoke_tests +run_padding_basic_boundary_tests +run_fp8_tests +run_fp8bf16_tests +run_fp8fp32_tests if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp index f0a71670ad..8c7589dabb 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -231,7 +231,7 @@ struct SplitKTwoStageInvoker preprocess = clear_gemm_output; } - return ck_tile::launch_kernel_time_mask( + ave_time = ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel( @@ -245,6 +245,8 @@ struct SplitKTwoStageInvoker ck_tile::make_tuple(args.N, 1), // Output Stride input_tensors, static_cast(c_ptr))); + + return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 588b66ca43..07b925d0eb 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -72,6 +72,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; + static constexpr bool TiledMMAPermuteN = false; }; template diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp index b47dd8d8a7..d737a0f864 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp @@ -109,7 +109,7 @@ struct WeightPreshuffleInvoker GemmConfig::NumWaveGroups, false, 1, - true>>; + GemmConfig::TiledMMAPermuteN>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 19855c7f72..e0d97a50db 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -174,10 +174,12 @@ struct UniversalInvoker preprocess = clear_gemm_output; } - return ck_tile::launch_kernel_time_mask( + ave_time = ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 4f3b173c55..bbfb2df006 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,10 +1,12 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) +add_executable(tile_example_grouped_gemm_multi_d EXCLUDE_FROM_ALL grouped_gemm_multi_d.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) \ No newline at end of file diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index 94481fa7b7..09bf3e167a 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -1,140 +1,8 @@ -# Grouped Gemm - -Grouped General Matrix Multiplication (Grouped GEMM) is a technique used in GPU computing and high-performance computing to batch together multiple independent GEMM operations (matrix multiplications) into a single kernel launch in order to improve performance and efficiency. This folder contains Grouped GEMM examples that use the ck_tile tile-programming implementation. - ## Quick Tour for New Users The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operations within a single kernel call. Each GEMM operation performs a matrix multiplication. Unlike regular batched GEMM operations where both matrices must be of the same size and have the same configuration, Grouped GEMM operations can take matrices with different sizes and configurations, making them more flexible for diverse workloads. -Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function. - -### Key Arguments -The example takes several arguments including `group_count`, `repeat`, and `warmup`: -- `group_count`: the number of GEMM operations in the group -- `repeat`: the number of times to repeat the kernel for benchmarking -- `warmup`: the number of iterations before the actual kernel run time measure - -```cpp -// Example -const int group_count = arg_parser.get_int("group_count"); -const int repeat = arg_parser.get_int("repeat"); -const int warmup = arg_parser.get_int("warmup"); -``` -In the next step, the input parameters `Ms`, `Ns`, `Ks`, as well as the corresponding `stride_As`, `stride_Bs`, and `stride_Cs` are either provided from the comand line or generated by default. Since one or more input data sets are expected for `A` and `B`, each parameter is stored in a `std::vector`. The size of the `vector` is defined by `group_count`. - -```cpp -// Example -std::vector Ms = arg_parser.get_int_vec("Ms"); -std::vector Ns = arg_parser.get_int_vec("Ns"); -std::vector Ks = arg_parser.get_int_vec("Ks"); -std::vector stride_As = arg_parser.get_int_vec("stride_As"); -std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); -std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); -``` -Where: -- `Ms` is the M dimension of each GEMM. -- `Ns` is the N dimension of each GEMM. -- `Ks` is the K dimension of each GEMM. -- `stride_As` is the stride values for matrix A. -- `stride_Bs` is the stride values for matrix B. -- `stride_Cs` is the stride values for matrix C. - -### HostTensor and Device Memory Buffers (for CPU and GPU) -Each parameter `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs` and `stride_Cs` contains values for more than one matrix, meaning different matrix sizes and strides can be used for different grouped GEMM computations. -The next step is to properly load the input values. For each input matrix, `A` and `B`, and for each output matrix, `C`, you need to create both `HostTensor` and `DeviceMemory`, where: -- `HostTensor` represents the matrix data on the host (CPU). It stores the data before they are transferred to the device for computation. -- `DeviceMemory` represents the matrix data on the device (GPU). This will store the data on the GPU for computation during the Grouped GEMM operation. - -#### HostTensor Buffers (for CPU) -In the first step, create `HostTensor` for `A`, `B`, `C`. `HostTensor` allocates memory on the host (CPU) to store the matrices, initializing the memory with the appropriate dimensions and values to store the data. Below is an example code showing how to create HostTensors for those tensors: -```cpp -// Example -std::vector> a_m_k_tensors; -std::vector> b_k_n_tensors; -std::vector> c_m_n_tensors; -``` -Where: -- `a_m_k_tensors` is the vector of `HostTensor` objects for matrix `A` (with dimensions `M × K`). Each tensor stores the data for single GEMM operation. -- `b_k_n_tensors` is the vector of `HostTensor` objects for matrix `B` (with dimensions `K × N`). -- `c_m_n_tensors` is the vector of `HostTensor` objects for matrix `C` (the output matrix with dimensions `M × N`). - -The `std::vector` container is used for this purpose throughout. As mentioned above, the number of HostTensors is equal to `group_count`. - -#### Device Memory Buffers (for GPU) -Now it's time to allocate memory on the device (GPU) and transfer the data from `HostTensor` to `DeviceMemory` for actual computation.. -```cpp -// Example -std::vector> a_m_k_dev_buf; -std::vector> b_k_n_dev_buf; -std::vector> c_m_n_dev_buf; -``` -Where: -- `a_m_k_dev_buf` is the buffer used for storing matrix A on the GPU. -- `b_k_n_dev_buf` is the buffer used for storing matrix B on the GPU. -- `c_m_n_dev_buf` is the buffer used for storing the result matrix C on the GPU. - -## Prepare data -In the next step, the input tensors are populated. A pseudorandom number generator, an existing distribution (e.g., `FillUniformDistribution`), or user data can be used to populate the tensors. Descriptors also need to be create for each input tensor. - -Use `get_default_stride` to get the strides for A, B, and C. `get_default_stride` is a template function that calculates the default stride for a 2D array based on whether it is row-major or column-major. Template parameter determines whether the storage order is row-major (true) or column-major (false). The function takes four params `row`, `col`, `stride` and `bool_constant`. If the stride is explicitly provided (`stride != 0`), the stride is returned as-is. If the stride is not provided (`stride == 0`), the function computes the default stride. For the Row-major order (`is_row_major == true`), the stride is set to the number of columns (col). For the column-major order (`is_row_major == false`), the stride is set to the number of rows (row). This function is useful when working with dynamically allocated 2D arrays, where the user may not specify the stride explicitly. It ensures a natural default stride based on the chosen storage order. - -```cpp -// Example, API -template -auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant) { - // code -} -``` - -Where: -- `is_row_major` is a bool template parameter that determines whether the storage order is row-major (true) or column-major (false). -- `row` is the number of rows in the matrix. -- `col` is the number of columns in the matrix. -- `stride` is the current stride (the distance between consecutive elements in memory). -- `bool_constant` is a tag type that helps in differentiating behavior at compile-time. - -Next host descriptors for each of the input tensors, A, B, and C are created. Use the `f_host_tensor_descriptor` function defined below. This function takes four parameters, row, col, stride, and layout, and returns a HostTensorDescriptor based on the specified layout. - -```cpp -// Example for tensor A -ck_tile::HostTensor(f_host_tensor_descriptor(M, K, stride_As[i], a_layout))) -``` - -After creating the host_tensors, create `deviceMem` for each tensor `A`, `B`, and `C`, and then transfer the data to the device. The `get_element_space_size_in_bytes()` function is used to get the buffer size in bytes. Use `ToDevice()` to transfer data from the host to the device. The data that was previously generated (`a_m_k_tensors[i].data()`) is passed as a parameter to `ToDevice()`. - -The final step before running the GEMM operation is to retrieve the pointers to the buffers of `A`, `B`, and `C` stored on the device using `->GetDeviceBuffer()` and pack them into a shared container. For example: `gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]})`, where `gemm_descs` is `std::vector gemm_descs` ([Code](https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc#L221)). The container should include values such as: -```cpp -struct GroupedGemmHostArgs -{ - const void* a_ptr; - const void* b_ptr; - void* c_ptr; - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - index_t stride_C; -}; -``` -The data prepared in this way can be passed to the `invoke_gemm` function. This is a templated function that also takes three template parameters: `ALayout`, `BLayout`, and `CLayout`: -```cpp -// Example, API -template -float invoke_gemm(int n_warmup, - int n_repeat, - int group_count, - const std::vector& args) -``` -`invoke_gemm` returns the run time in milliseconds. The workspace memory required for computation is allocated. Workspace memory on the GPU refers to temporary memory buffers allocated when some operations are run. This extra space is needed to hold GEMM descriptions. The following structure can be used to allocate workspace: - -```cpp -// Example -ck_tile::DeviceMem gemm_workspace; -gemm_workspace.Realloc(GetWorkspaceSize(args)); -``` - -### Advanced Features: Preshuffle and Persistence +### Preshuffle and Persistence The grouped GEMM examples include two advanced optimization features: @@ -142,28 +10,28 @@ The grouped GEMM examples include two advanced optimization features: Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches. - **Implementation**: Available in `grouped_gemm_preshuffle.cpp` -- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration +- **Configuration**: Uses `GemmConfigPreshuffleDecode` and `GemmConfigPreshufflePrefill` template configuration - **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts -- **Benefits**: Improved memory efficiency and reduced data movement + #### Persistence Mode Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy. - **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm` - **Usage**: `invoke_gemm` enables persistence -- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes -Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads. +#### Multi-D Operations +Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output. -Finally the arguments are passed to group_gemm and the kernel is launched. -```cpp -// API -template -float grouped_gemm(const std::vector& gemm_descs, - const ck_tile::stream_config& s, - void* kargs_ptr) -``` -All the necessary parameters are set, the tiling is computed, the GEMM pipeline and epilogue are prepared, and the GroupedGemmKernel is launched. +- **Implementation**: Available in `grouped_gemm_multi_d.cpp` +- **Operation**: E = C × D₀ × D₁ (where C = A × B is the standard GEMM result) +- **Configuration**: Uses `GemmConfigV3`, `GemmConfigV4`, `GemmConfigMemory` template configuration with 2 D tensors +- **Data Types**: Supports fp16 +- **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call +- **Build Target**: `make tile_example_grouped_gemm_multi_d -j` + +Multi-D operations supports both persistence and non-persistence modes. +Weight preshuffle supports only on non-persistence mode. ## Build ``` @@ -175,10 +43,13 @@ mkdir build && cd build make tile_example_grouped_gemm -j # The preshuffle example make tile_example_grouped_gemm_preshuffle -j +# The multi-D operations example +make tile_example_grouped_gemm_multi_d -j # The quant grouped gemm fp8 example make tile_example_quant_grouped_gemm -j ``` -This will result in an executable `build/bin/tile_example_grouped_gemm` +Each example will result in an corresponding executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`. + ## example ``` @@ -213,4 +84,4 @@ K[i] = 512 + 384 * i stride_A[i] = K[i] stride_B[i] = K[i] stride_C[i] = N[i] -``` +``` \ No newline at end of file diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 6493a542ba..10d7befc06 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -9,7 +9,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/utility/json_dump.hpp" #define CK_TILE_PIPELINE_COMPUTE_V3 1 @@ -296,7 +295,7 @@ struct PipelineTypeTraits ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; }; -using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; +using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>; std::pair create_args(int argc, char* argv[]) { @@ -325,7 +324,7 @@ std::pair create_args(int argc, char* argv[]) inline std::size_t get_workspace_size(const std::vector& gemm_descs) { - return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>); } template diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp new file mode 100644 index 0000000000..98b0428d39 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "grouped_gemm_multi_d.hpp" + +template +float grouped_gemm_multi_d(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + + return ave_time; +} + +template +float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + if(!splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + + return ave_time; +} + +#include "run_grouped_gemm_multi_d_example.inc" + +int main(int argc, char* argv[]) +{ +#if CK_TILE_USE_WMMA + return !run_grouped_gemm_multi_d_example(argc, argv); +#else + return !run_grouped_gemm_multi_d_example(argc, argv) || + !run_grouped_gemm_multi_d_example(argc, argv) || + !run_grouped_gemm_multi_d_example(argc, argv); +#endif +} diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp new file mode 100644 index 0000000000..0789452ada --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/utility/json_dump.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +using ADataType = ck_tile::half_t; +using BDataType = ck_tile::half_t; +using D0DataType = ck_tile::half_t; +using D1DataType = ck_tile::half_t; +using EDataType = ck_tile::half_t; +using DsDataType = ck_tile::tuple; +using AccDataType = float; + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet + static constexpr bool Persistent = false; // currently persistent == true is not supported yet + static constexpr bool DoubleSmemBuffer = + false; // currently double smem buffer == true is not supported yet +}; + +struct GemmConfigMemory : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 8; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool Persistent = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +struct GemmConfigV3 : public GemmConfigBase +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; +struct GemmConfigV4 : public GemmConfigBase +{ + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigV3_Wmma : public GemmConfigBase +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs; + +std::pair create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("Ms", "", "M dimensions - empty by default.") + .insert("Ns", "", "N dimensions - empty by default.") + .insert("Ks", "", "K dimensions - empty by default.") + .insert("stride_As", "", "Tensor A strides - it is empty by default.") + .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") + .insert("stride_Ds", "", "Tensor Ds strides - it is empty by default.") + .insert("stride_Es", "", "Tensor E strides - it is empty by default.") + .insert("a_layout", "R", "A tensor data layout - Row by default.") + .insert("b_layout", "C", "B tensor data layout - Row by default.") + .insert("ds_layout", "R", "Ds tensor data layout - Row by default.") + .insert("e_layout", "R", "E tensor data layout - Row by default.") + .insert("validate", "1", "0. No validation, 1. Validation on CPU.") + .insert("prec", "fp16", "data type. fp16") + .insert("warmup", "10", "number of iterations before benchmark the kernel.") + .insert("repeat", "100", "number of iterations to benchmark the kernel.") + .insert("group_count", "8", "group count.") + .insert("kbatch", "1", "kbatch for SplitK") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "grouped_gemm.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_pair(result, arg_parser); +} + +inline std::size_t get_workspace_size(const std::vector& gemm_descs) +{ + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); +} + +template +float grouped_gemm_multi_d(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 83542e76f1..409bb173a1 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -13,7 +13,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm_group_quant.hpp" +#include "ck_tile/ops/gemm_quant.hpp" #include "ck_tile/host.hpp" #include "quant_grouped_gemm.hpp" @@ -65,15 +65,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, constexpr auto memory_operation = memory_operation_.value; constexpr bool transpose_c = false; - using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem; + using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem; using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 17e0ee5342..10d317a2c7 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -183,12 +183,24 @@ int run_grouped_gemm_example_with_layouts(int argc, if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs)) { std::cout << "Please check the input data. Default values will be used." << std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + stride_AQs.clear(); + stride_BQs.clear(); + for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); Ns.push_back(256 + 512 * i); Ks.push_back(512 + 128 * i); + // Let get_default_stride calculate based on layout stride_As.push_back(0); stride_Bs.push_back(0); stride_Cs.push_back(0); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 1cd2212994..f822c7d8a7 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -88,7 +88,7 @@ float invoke_gemm(int n_warmup, // The contents of the memory pointed to by `kargs_ptr` pointer could be // written by e.g. another kernel from earlier stage. - std::vector kargs; + std::vector> kargs; void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); const bool splitk = args[0].k_batch > 1; for(const auto& arg : args) @@ -109,7 +109,7 @@ float invoke_gemm(int n_warmup, const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, kargs.data(), - kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>), hipMemcpyHostToDevice, stream.stream_id_)); ave_time = grouped_gemm_tileloopGetDeviceBuffer(); void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); - gemm_descs.push_back( - {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + gemm_descs.push_back({p_a, + p_b, + {/*ds_ptr*/}, + p_c, + kbatch, + M, + N, + K, + stride_As[i], + stride_Bs[i], + {/*stride_Ds*/}, + stride_Cs[i]}); } float ave_time = invoke_gemm + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_gemm(int n_warmup, + int n_repeat, + int group_count, + const std::vector& args) +{ + // Workspace memory allocated to hold the gemm descriptions. + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(args)); + + float ave_time = 0; + if constexpr(!GemmConfig::Persistent) + { + ave_time = grouped_gemm_multi_d( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + std::vector> kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + const bool splitk = args[0].k_batch > 1; + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, 2>{{arg.a_ptr}, + {arg.b_ptr}, + arg.ds_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + arg.stride_Ds, + arg.stride_E, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream( + kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = + grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr, splitk); + } + return ave_time; +} + +template +int run_grouped_gemm_multi_d_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + const D0Layout d0_layout = D0Layout{}, + const D1Layout d1_layout = D1Layout{}, + const ELayout e_layout = ELayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + + using CDElementWise = MultiplyMultiply; + using DsLayout = ck_tile::tuple; + + auto valid_input_data = [&](int group_count, const auto&... args) { + return !(args.empty() || ...) && group_count == (args.size() == ...); + }; + + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + bool validate = arg_parser.get_bool("validate"); + + if(kbatch > 1 && validate && warmup + repeat > 1) + { + std::cout << "WARNING: Data validation enabled with SplitK and more than" + << "1 warmup/repeat. Disabling validation." << std::endl; + validate = false; + } + + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector stride_As = arg_parser.get_int_vec("stride_As"); + std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); + std::vector stride_D0 = arg_parser.get_int_vec("stride_Ds"); + std::vector stride_D1 = arg_parser.get_int_vec("stride_Ds"); + std::vector stride_Es = arg_parser.get_int_vec("stride_Es"); + + if(!valid_input_data( + group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_D0, stride_D1, stride_Es)) + { + std::cout << "Please check the input data. Default values will be used." << std::endl; + std::cout << "Default values: Ms (256, 512, 768, 1024..), Ns (256, 768, 1280..), Ks (512, " + "896, 1280..), stride_As (Ks), stride_Bs (Ks), stride_D0 (Ns), stride_D1 " + "(Ns), stride_Es (Ns)" + << std::endl; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 384 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_D0.push_back(Ns[i]); + stride_D1.push_back(Ns[i]); + stride_Es.push_back(Ns[i]); + } + } + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> d0_m_n_tensors; + std::vector> d1_m_n_tensors; + std::vector> e_m_n_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + d0_m_n_tensors.reserve(group_count); + d1_m_n_tensors.reserve(group_count); + e_m_n_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> d0_m_n_dev_buf; + std::vector> d1_m_n_dev_buf; + std::vector> e_m_n_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + d0_m_n_dev_buf.reserve(group_count); + d1_m_n_dev_buf.reserve(group_count); + e_m_n_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + + stride_D0[i] = ck_tile::get_default_stride(M, N, stride_D0[i], is_row_major(d0_layout)); + stride_D1[i] = ck_tile::get_default_stride(M, N, stride_D1[i], is_row_major(d1_layout)); + + stride_Es[i] = ck_tile::get_default_stride(M, N, stride_Es[i], is_row_major(e_layout)); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); + + d0_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_D0[i], is_row_major(d0_layout)))); + d1_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_D1[i], is_row_major(d1_layout)))); + + e_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Es[i], is_row_major(e_layout)))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " d0_m_n: " << d0_m_n_tensors[i].mDesc + << " d1_m_n: " << d1_m_n_tensors[i].mDesc << " e_m_n: " << e_m_n_tensors[i].mDesc + << std::endl; + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique(a_m_k_tensors[i])); + + b_k_n_dev_buf.push_back(std::make_unique(b_k_n_tensors[i])); + + d0_m_n_dev_buf.push_back(std::make_unique(d0_m_n_tensors[i])); + d1_m_n_dev_buf.push_back(std::make_unique(d1_m_n_tensors[i])); + e_m_n_dev_buf.push_back(std::make_unique(e_m_n_tensors[i])); + + e_m_n_dev_buf[i]->SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_e = e_m_n_dev_buf[i]->GetDeviceBuffer(); + + std::array ds_ptr_buf = { + d0_m_n_dev_buf[i]->GetDeviceBuffer(), d1_m_n_dev_buf[i]->GetDeviceBuffer()}; + std::array stridesDs = {stride_D0[i], stride_D1[i]}; + + gemm_descs.push_back({p_a, + p_b, + ds_ptr_buf, + p_e, + kbatch, + M, + N, + K, + stride_As[i], + stride_Bs[i], + stridesDs, + stride_Es[i]}); + } + + float ave_time = invoke_gemm(warmup, repeat, group_count, gemm_descs); + + std::string op_name{"Grouped Gemm Multiple-D"}; + + std::size_t flop = 0, num_btype = 0; + for(int j = 0; j < group_count; ++j) + { + flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K; + ck_tile::static_for<0, DsDataType::size(), 1>{}([&](auto i) { + num_btype += sizeof(ck_tile::remove_cvref_t>) * + gemm_descs[j].M * gemm_descs[j].N; + flop += sizeof(ck_tile::remove_cvref_t>) * + gemm_descs[j].M * gemm_descs[j].N; + }); + + num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K + + sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N + + sizeof(EDataType) * gemm_descs[j].M * gemm_descs[j].N; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + std::vector> e_m_n_host_refs; + e_m_n_host_refs.reserve(group_count); + + // copy e_m_n_tensors result from device to host and initialize host tensors to zero + for(int i = 0; i < group_count; i++) + { + e_m_n_dev_buf[i]->FromDevice(e_m_n_tensors[i].data()); + } + + bool pass{true}; + if(validate) + { + for(int i = 0; i < group_count; ++i) + { + e_m_n_host_refs.push_back(ck_tile::HostTensor( + host_tensor_descriptor(Ms[i], Ns[i], stride_Es[i], is_row_major(e_layout)))); + + e_m_n_host_refs[i].SetZero(); + + ck_tile::reference_gemm_multiple_d( + a_m_k_tensors[i], + b_k_n_tensors[i], + {d0_m_n_tensors[i], d1_m_n_tensors[i]}, + e_m_n_host_refs[i]); + + const float max_accumulated_value = + *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); + + const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value); + + pass &= + ck_tile::check_err(e_m_n_tensors[i], + e_m_n_host_refs[i], + "Error: Incorrect results! in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + if(arg_parser.get_int("json") == 1) + { + dump_grouped_gemm_json_results(arg_parser.get_str("jsonfile"), + op_name, + group_count, + pass, + ave_time, + tflops, + gb_per_sec); + } + + return pass; +} + +template +int run_grouped_gemm_multi_d_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string ds_layout = arg_parser.get_str("ds_layout"); + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(a_layout == "R" && b_layout == "C" && ds_layout == "R") + { + return run_grouped_gemm_multi_d_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for provided tensors!"); + } +} diff --git a/example/ck_tile/20_grouped_convolution/CMakeLists.txt b/example/ck_tile/20_grouped_convolution/CMakeLists.txt index 5cb1d2650e..10332137e2 100644 --- a/example/ck_tile/20_grouped_convolution/CMakeLists.txt +++ b/example/ck_tile/20_grouped_convolution/CMakeLists.txt @@ -7,5 +7,8 @@ target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMP add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp) target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +add_executable(tile_example_grouped_conv_bwd_weight_two_stage EXCLUDE_FROM_ALL grouped_convolution_backward_weight_two_stage.cpp) +target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp) target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp index c456e7a2e8..fa914a7119 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp @@ -11,190 +11,14 @@ #include "ck_tile/host.hpp" #include "grouped_convolution_utils.hpp" - -template , - typename DsLayout = ck_tile::tuple<>, - typename CDEElementWise = ck_tile::element_wise::PassThrough> -float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args, - const ck_tile::stream_config& s) -{ - constexpr int kBlockPerCu = 1; - - constexpr ck_tile::index_t M_Tile = 64; - constexpr ck_tile::index_t N_Tile = 64; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; - constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; - constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; - - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - // Implicit GEMM Traits - using CodegenShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using GroupedConvTraitsType = - ck_tile::GroupedConvTraits; - using CodegenPipelineProblem = - ck_tile::GemmPipelineProblem; - using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - - using ConvEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << '\n' - << "Vector size A: " << CodegenPipeline::GetVectorSizeA() - << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - - if(args.k_batch == 1) - { - return Run(ck_tile::integral_constant{}); - } - else - { - return Run(ck_tile::integral_constant{}); - } -} - +#include "grouped_convolution_backward_data_invoker.hpp" #include "run_grouped_convolution_bwd_data_example.inc" -template -int run_grouped_conv_bwd_data_example_prec_type( - std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) -{ - using NWGC = ck_tile::tensor_layout::convolution::NWGC; - using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; - using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; - - using GKXC = ck_tile::tensor_layout::convolution::GKXC; - using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; - using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; - - using NWGK = ck_tile::tensor_layout::convolution::NWGK; - using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; - using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; - - if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") - { - return run_grouped_conv_bwd_data_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NWGC{}, GKXC{}, NWGK{}); - } - else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") - { - return run_grouped_conv_bwd_data_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); - } - else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") - { - return run_grouped_conv_bwd_data_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); - } - else - { - throw std::runtime_error("Unsupported memory layout!"); - } -} - template int run_grouped_conv_bwd_data_example(int argc, char* argv[]) { + using Invoker = GroupedConvolutionBackwardDataInvoker; + auto [result, arg_parser] = create_args(argc, argv); if(!result) return -1; @@ -206,12 +30,16 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[]) if(data_type == "fp16") { - return run_grouped_conv_bwd_data_example_prec_type( + return run_grouped_conv_bwd_data_example_prec_type( in_layout, wei_layout, out_layout, argc, argv); } else if(data_type == "bf16") { - return run_grouped_conv_bwd_data_example_prec_type( + return run_grouped_conv_bwd_data_example_prec_type( in_layout, wei_layout, out_layout, argc, argv); } else diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp new file mode 100644 index 0000000000..1b3d45427d --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "grouped_convolution_utils.hpp" + +struct GroupedConvolutionBackwardDataInvoker +{ + + template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> + static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; + + constexpr ck_tile::index_t VectorSizeA = 1; + constexpr ck_tile::index_t VectorSizeB = 1; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + CodegenShape, + typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + true, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } + } +}; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp index d22b0ff58c..82614bbb13 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp @@ -11,195 +11,13 @@ #include "ck_tile/host.hpp" #include "grouped_convolution_utils.hpp" - -template , - typename DsLayout = ck_tile::tuple<>, - typename CDEElementWise = ck_tile::element_wise::PassThrough> -float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, - const ck_tile::stream_config& s) -{ - constexpr int kBlockPerCu = 1; - - constexpr ck_tile::index_t M_Tile = 64; - constexpr ck_tile::index_t N_Tile = 64; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; - constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; - constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; - - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - // Implicit GEMM Traits - using CodegenShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using GroupedConvTraitsType = - ck_tile::GroupedConvTraits; - using CodegenPipelineProblem = - ck_tile::GemmPipelineProblem; - using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - - using ConvEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << '\n' - << "Vector size A: " << CodegenPipeline::GetVectorSizeA() - << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - float ave_time = ck_tile::launch_kernel_time_mask( - s, - Kernel::Preprocess(kargs, s), - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - - if(args.k_batch == 1) - { - return Run(ck_tile::integral_constant{}); - } - else - { - return Run(ck_tile::integral_constant{}); - } -} - +#include "grouped_convolution_backward_weight_invoker.hpp" #include "run_grouped_convolution_bwd_weight_example.inc" -template -int run_grouped_conv_bwd_weight_example_prec_type( - std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) -{ - using NWGC = ck_tile::tensor_layout::convolution::NWGC; - using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; - using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; - - using GKXC = ck_tile::tensor_layout::convolution::GKXC; - using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; - using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; - - using NWGK = ck_tile::tensor_layout::convolution::NWGK; - using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; - using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; - - if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") - { - return run_grouped_conv_bwd_weight_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NWGC{}, GKXC{}, NWGK{}); - } - else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") - { - return run_grouped_conv_bwd_weight_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); - } - else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") - { - return run_grouped_conv_bwd_weight_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); - } - else - { - throw std::runtime_error("Unsupported memory layout!"); - } -} - template -int run_grouped_conv_bwd_weight_example(int argc, char* argv[]) +int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; + using Invoker = GroupedConvolutionBackwardWeightInvoker; std::string data_type = arg_parser.get_str("prec"); std::string in_layout = arg_parser.get_str("in_layout"); @@ -208,13 +26,17 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[]) if(data_type == "fp16") { - return run_grouped_conv_bwd_weight_example_prec_type( - in_layout, wei_layout, out_layout, argc, argv); + return run_grouped_conv_bwd_weight_example_prec_type( + in_layout, wei_layout, out_layout, arg_parser); } else if(data_type == "bf16") { - return run_grouped_conv_bwd_weight_example_prec_type( - in_layout, wei_layout, out_layout, argc, argv); + return run_grouped_conv_bwd_weight_example_prec_type( + in_layout, wei_layout, out_layout, arg_parser); } else { @@ -224,9 +46,22 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[]) int main(int argc, char* argv[]) { + + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + try + { #if CK_TILE_USE_WMMA - return !run_grouped_conv_bwd_weight_example(argc, argv); + return !run_grouped_conv_bwd_weight_example(arg_parser); #else - return !run_grouped_conv_bwd_weight_example(argc, argv); + return !run_grouped_conv_bwd_weight_example(arg_parser); #endif + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } } diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp new file mode 100644 index 0000000000..f46707d1d2 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "grouped_convolution_utils.hpp" + +struct GroupedConvolutionBackwardWeightInvoker +{ + template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> + static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; + + constexpr ck_tile::index_t VectorSizeA = 1; + constexpr ck_tile::index_t VectorSizeB = 1; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + CodegenShape, + typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + true, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + float ave_time = ck_tile::launch_kernel_time_mask( + s, + Kernel::Preprocess(kargs, s), + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } + } +}; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage.cpp new file mode 100644 index 0000000000..82068156e5 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "grouped_convolution_utils.hpp" +#include "grouped_convolution_backward_weight_two_stage_invoker.hpp" +#include "run_grouped_convolution_bwd_weight_example.inc" + +template +int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser) +{ + using Invoker = GroupedConvolutionBackwardWeightTwoStageInvoker; + + std::string data_type = arg_parser.get_str("prec"); + std::string in_layout = arg_parser.get_str("in_layout"); + std::string wei_layout = arg_parser.get_str("wei_layout"); + std::string out_layout = arg_parser.get_str("out_layout"); + + if(data_type == "fp16") + { + return run_grouped_conv_bwd_weight_example_prec_type( + in_layout, wei_layout, out_layout, arg_parser); + } + else if(data_type == "bf16") + { + return run_grouped_conv_bwd_weight_example_prec_type( + in_layout, wei_layout, out_layout, arg_parser); + } + else + { + throw std::runtime_error("Unsupported data type for this operation!"); + } +} + +int main(int argc, char* argv[]) +{ + + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + try + { +#if CK_TILE_USE_WMMA + return !run_grouped_conv_bwd_weight_example(arg_parser); +#else + return !run_grouped_conv_bwd_weight_example(arg_parser); +#endif + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp new file mode 100644 index 0000000000..a8f321e7f9 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "grouped_convolution_utils.hpp" + +struct GroupedConvolutionBackwardWeightTwoStageInvoker +{ + template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> + static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, + const ck_tile::stream_config& s) + { + using WorkspaceDataType = float; + + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; + + constexpr ck_tile::index_t VectorSizeA = 1; + constexpr ck_tile::index_t VectorSizeB = 1; + constexpr ck_tile::index_t VectorSizeC = 1; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem< + OutDataType, // A: Out + InDataType, // B: In + AccDataType, + CodegenShape, + typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + true, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + + const ck_tile::index_t spatial_lengths_accum = + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum * + sizeof(WorkspaceDataType)); + ck_tile::GroupedConvBwdWeightHostArgs ws_args = + ck_tile::GroupedConvBwdWeightHostArgs(args); + auto c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + auto kargs = Kernel::MakeKernelArgs(ws_args); + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } + + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; + + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; + + ck_tile::index_t total_elements = 1; + std::vector shape = { + static_cast(args.G_ * args.K_), + static_cast(args.C_ * spatial_lengths_accum)}; + + for(auto d : shape) + total_elements *= d; + + const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); + + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = + (total_elements + elements_per_block - 1) / elements_per_block; + + auto input_tensors = + ck_tile::make_tuple(static_cast(ws_args.wei_ptr)); + auto input_size = ck_tile::make_tuple(shape[0], shape[1]); + + // Check if the kernel configuration is supported + if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + auto preprocess = [&]() { + if(args.k_batch > 1) + ck_tile::hip_check_error( + hipMemsetAsync(ws_args.wei_ptr, + 0, + shape[0] * shape[1] * sizeof(WorkspaceDataType), + s.stream_id_)); + }; + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), + ck_tile::make_kernel(ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(shape[1], 1), // Input Stride + ck_tile::make_tuple(shape[1], 1), // Output Stride + input_tensors, + static_cast(c_ptr))); + }; + + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } + } +}; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index d74fa1b292..4cddbae3ab 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -11,181 +11,14 @@ #include "ck_tile/host.hpp" #include "grouped_convolution_utils.hpp" - -template , - typename DsLayout = ck_tile::tuple<>, - typename CDEElementWise = ck_tile::element_wise::PassThrough> -float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s) -{ - constexpr int kBlockPerCu = 1; - - constexpr ck_tile::index_t M_Tile = 64; - constexpr ck_tile::index_t N_Tile = 64; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; - constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; - constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; - - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - // Implicit GEMM Traits - using CodegenShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using GroupedConvTraitsType = - ck_tile::GroupedConvTraits; - using CodegenPipelineProblem = - ck_tile::GemmPipelineProblem; - using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - - using ConvEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::GroupedConvolutionForwardKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << '\n' - << "Vector size A: " << CodegenPipeline::GetVectorSizeA() - << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - - return Run(ck_tile::integral_constant{}); -} - +#include "grouped_convolution_forward_invoker.hpp" #include "run_grouped_convolution_fwd_example.inc" -template -int run_grouped_conv_fwd_example_prec_type( - std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) -{ - using NWGC = ck_tile::tensor_layout::convolution::NWGC; - using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; - using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; - - using GKXC = ck_tile::tensor_layout::convolution::GKXC; - using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; - using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; - - using NWGK = ck_tile::tensor_layout::convolution::NWGK; - using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; - using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; - - if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") - { - return run_grouped_conv_fwd_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NWGC{}, GKXC{}, NWGK{}); - } - else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") - { - return run_grouped_conv_fwd_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); - } - else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "GKZYXC") - { - return run_grouped_conv_fwd_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); - } - else - { - throw std::runtime_error("Unsupported memory layout!"); - } -} - template int run_grouped_conv_fwd_example(int argc, char* argv[]) { + using Invoker = GroupedConvolutionForwardInvoker; + auto [result, arg_parser] = create_args(argc, argv); if(!result) return -1; @@ -197,12 +30,12 @@ int run_grouped_conv_fwd_example(int argc, char* argv[]) if(data_type == "fp16") { - return run_grouped_conv_fwd_example_prec_type( + return run_grouped_conv_fwd_example_prec_type( in_layout, wei_layout, out_layout, argc, argv); } else if(data_type == "bf16") { - return run_grouped_conv_fwd_example_prec_type( + return run_grouped_conv_fwd_example_prec_type( in_layout, wei_layout, out_layout, argc, argv); } else diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp new file mode 100644 index 0000000000..0b9879d247 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "grouped_convolution_utils.hpp" + +struct GroupedConvolutionForwardInvoker +{ + template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> + static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; + + constexpr ck_tile::index_t VectorSizeA = 8; + constexpr ck_tile::index_t VectorSizeB = 8; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + CodegenShape, + typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + true, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + return Run(ck_tile::integral_constant{}); + } +}; diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc index 8519daaac2..3d7635bf4f 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc @@ -4,6 +4,7 @@ template ( + float ave_time = Invoker::template grouped_conv_bwd_data( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = args.GetFlops(); @@ -39,6 +40,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args, template +int run_grouped_conv_bwd_data_example_prec_type( + std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) +{ + using NWGC = ck_tile::tensor_layout::convolution::NWGC; + using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; + using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; + + using GKXC = ck_tile::tensor_layout::convolution::GKXC; + using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; + using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; + + using NWGK = ck_tile::tensor_layout::convolution::NWGK; + using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; + using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; + + if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NWGC{}, GKXC{}, NWGK{}); + } + else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); + } + else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); + } + else + { + throw std::runtime_error("Unsupported memory layout!"); + } +} diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc index eedc341235..5e640514f9 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc @@ -4,6 +4,7 @@ template ( + float ave_time = Invoker::template grouped_conv_bwd_weight( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - std::size_t flop = args.GetFlops(); - std::size_t num_byte = args.GetByte(); - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; - return ave_time; } template -int run_grouped_conv_bwd_weight_example_with_layouts( - int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout) +int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_parser, + const InLayout, + const WeiLayout, + const OutLayout) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - using AccDataType = float; std::vector filter_spatial_lengths; @@ -138,17 +130,27 @@ int run_grouped_conv_bwd_weight_example_with_layouts( std::cout << "weight: " << weight.mDesc << std::endl; std::cout << "output: " << output.mDesc << std::endl; - invoke_grouped_conv_bwd_weight(args, n_warmup, n_repeat); + float ave_time = invoke_grouped_conv_bwd_weight(args, n_warmup, n_repeat); weight_dev_buf.FromDevice(weight.data()); + + std::size_t flop = args.GetFlops(); + std::size_t num_byte = args.GetByte(); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + bool pass = true; if(arg_parser.get_int("v") == 1) @@ -189,3 +191,61 @@ int run_grouped_conv_bwd_weight_example_with_layouts( return pass; } + +template +int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout, + std::string wei_layout, + std::string out_layout, + ck_tile::ArgParser& arg_parser) +{ + using NWGC = ck_tile::tensor_layout::convolution::NWGC; + using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; + using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; + + using GKXC = ck_tile::tensor_layout::convolution::GKXC; + using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; + using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; + + using NWGK = ck_tile::tensor_layout::convolution::NWGK; + using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; + using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; + + if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") + { + return run_grouped_conv_bwd_weight_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + arg_parser, NWGC{}, GKXC{}, NWGK{}); + } + else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") + { + return run_grouped_conv_bwd_weight_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + arg_parser, NHWGC{}, GKYXC{}, NHWGK{}); + } + else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") + { + return run_grouped_conv_bwd_weight_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + arg_parser, NDHWGC{}, GKZYXC{}, NDHWGK{}); + } + else + { + throw std::runtime_error("Unsupported memory layout!"); + } +} diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc index c5ae92a0da..beb6005e19 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc @@ -4,6 +4,7 @@ template ( + float ave_time = Invoker::template grouped_conv_fwd( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = args.GetFlops(); @@ -39,6 +40,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, template +int run_grouped_conv_fwd_example_prec_type( + std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) +{ + using NWGC = ck_tile::tensor_layout::convolution::NWGC; + using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; + using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; + + using GKXC = ck_tile::tensor_layout::convolution::GKXC; + using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; + using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; + + using NWGK = ck_tile::tensor_layout::convolution::NWGK; + using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; + using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; + + if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") + { + return run_grouped_conv_fwd_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NWGC{}, GKXC{}, NWGK{}); + } + else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") + { + return run_grouped_conv_fwd_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); + } + else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") + { + return run_grouped_conv_fwd_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); + } + else + { + throw std::runtime_error("Unsupported memory layout!"); + } +} diff --git a/example/ck_tile/21_elementwise/elementwise_example.cpp b/example/ck_tile/21_elementwise/elementwise_example.cpp index 94d3e70be1..e9fbeafde1 100644 --- a/example/ck_tile/21_elementwise/elementwise_example.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example.cpp @@ -211,7 +211,9 @@ bool run(const ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp index ff7ec1517e..1b101c2e5f 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp @@ -157,7 +157,9 @@ bool run(const ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp index 16e9832c07..7cdb5cc0d1 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp @@ -156,7 +156,9 @@ bool run(const ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp index c5a08d910e..4e19cfd688 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp @@ -193,7 +193,9 @@ auto string_to_op(const std::string& op) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 9acc4f9bfc..7f8aba7b3d 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -5,6 +5,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming - AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline - BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline - Row and Column-wise scaled: scaling implemented in Epilogue +- Tensor-wise scaled: scaling implemented in Epilogue ## build ``` @@ -14,7 +15,6 @@ mkdir build && cd build ../script/cmake-ck-dev.sh ../ # Compile the quant kernels make tile_example_gemm_quant_basic -j -make tile_example_gemm_bquant_basic -j ``` This will result in an executable `build/bin/tile_example_gemm_quant_basic` @@ -37,7 +37,7 @@ args: -warmup number of iterations before benchmark the kernel (default:10) -repeat number of iterations to benchmark the kernel (default:100) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) - -quant_mode Which quant method to use (aquant, rowcol) + -quant_mode Which quant method to use (aquant, bquant, tensor, rowcol) ``` User need to select correct mapping of config for each quant mode: @@ -47,5 +47,6 @@ User need to select correct mapping of config for each quant mode: | For selecting AQuant | aquant | GemmConfigQuant | | For selecting Aquant with Preshuffle | aquant | GemmConfigPreshuffleQuant | | For selecting BQuant | bquant | GemmConfigQuant | +| For selecting PreShuffle Weight matrix with Bquant | bquant | GemmConfigPreshuffleB_Bquant_decode (or) GemmConfigPreshuffleB_Bquant_prefill | For selecting RowCol quant | rowcolquant | GemmConfigRowColQuant | diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp old mode 100644 new mode 100755 index 79c6cca6cb..fa9ad967ad --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -23,7 +23,6 @@ template ); - // B datatype is safe to use as compute type as it should be at least fp8 using ComputeDataType = std::conditional_t; + QuantMode, + ALayout, // for AQLayout + BLayout, // for BQLayout + GemmConfig::DoubleSmemBuffer>; using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + using BaseGemmPipeline = std::conditional_t< + GemmConfig::PreshuffleB == true, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -66,19 +72,21 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str constexpr auto tail_number_v = tail_number_.value; constexpr bool transpose_c = false; + // row-col and tensor quants use the regular pipeline, A/B quants use their own using PipelineProblem = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant, - ck_tile::GemmRowColQuantPipelineProblem, + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant, + ck_tile::GemmRowColTensorQuantPipelineProblem, std::conditional_t>>; using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant, + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant, ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + ck_tile::RotatingMemWrapper + rotating_mem( + kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString( + hipMemsetAsync(args.c_ptr, + 0, + args.M * args.N * sizeof(typename TypeConfig::CDataType), + s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } return ave_time; }; @@ -177,6 +229,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + if((QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::RowColQuant) && + GemmConfig::PreshuffleB) + { + throw std::runtime_error( + "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); + } + if constexpr(std::is_same_v || std::is_same_v || std::is_same_v) @@ -241,10 +301,18 @@ int run_gemm_example(int argc, char* argv[]) ck_tile::QuantType::RowColQuant>( a_layout, b_layout, argc, argv); } + else if(quant_mode == "tensor") + { + return run_gemm_example_prec_type, + TypeConfig, + 128, + ck_tile::QuantType::TensorQuant>( + a_layout, b_layout, argc, argv); + } else { throw std::runtime_error( - "Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'"); + "Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'"); } } else if(data_type == "bf8") @@ -276,10 +344,18 @@ int run_gemm_example(int argc, char* argv[]) ck_tile::QuantType::RowColQuant>( a_layout, b_layout, argc, argv); } + else if(quant_mode == "tensor") + { + return run_gemm_example_prec_type, + TypeConfig, + 128, + ck_tile::QuantType::TensorQuant>( + a_layout, b_layout, argc, argv); + } else { throw std::runtime_error( - "Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'"); + "Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'"); } } else if(data_type == "i4fp8") @@ -372,4 +448,7 @@ int run_gemm_example(int argc, char* argv[]) } } -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + return !run_gemm_example(argc, argv); +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index ccf07460fa..cfe7b72af9 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -9,7 +9,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm_group_quant.hpp" +#include "ck_tile/ops/gemm_quant.hpp" template constexpr ck_tile::index_t get_k_warp_tile() @@ -91,6 +91,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr bool PreshuffleQuant = false; + static constexpr bool PreshuffleB = false; static constexpr bool DoubleSmemBuffer = false; }; @@ -145,6 +146,26 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase static constexpr bool PreshuffleQuant = true; }; +template +struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + get_k_from_preshuffled_warp_tile(); + + static constexpr bool PreshuffleB = true; + static constexpr bool DoubleSmemBuffer = true; +}; + template * t, int block_aq_k) return ck_tile::reference_permute(t_view, {1, 0, 2}); } +template +auto shuffle_b(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); +} + template a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); @@ -237,28 +256,30 @@ int run_gemm_example_with_layouts(int argc, // Create AQ tensor with appropriate shape std::unique_ptr> aq_tensor_ptr = nullptr; - if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::RowColQuant) { aq_tensor_ptr = std::make_unique>( ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout))); } - else if(QuantMode == ck_tile::QuantType::RowColQuant) + else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) { aq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout))); + ck_tile::host_tensor_descriptor(1, 1, stride_AQ, is_row_major(aq_layout))); } - // Create BQ tensor only for RowColQuant mode + // Create BQ tensor with appropriate shape std::unique_ptr> bq_tensor_ptr = nullptr; - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped || + QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); } - else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) + else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) { bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(1, N, stride_BQ, is_row_major(bq_layout))); + ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout))); } std::random_device rd; @@ -282,7 +303,7 @@ int run_gemm_example_with_layouts(int argc, *bq_tensor_ptr); ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); } - else + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { if constexpr(std::is_same_v) { @@ -296,12 +317,15 @@ int run_gemm_example_with_layouts(int argc, ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *aq_tensor_ptr); ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - - if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); } } else if(init_method == 1) @@ -343,7 +367,8 @@ int run_gemm_example_with_layouts(int argc, std::unique_ptr aq_dev_buf_ptr = nullptr; if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) { aq_dev_buf_ptr = std::make_unique(aq_tensor_ptr->get_element_space_size_in_bytes()); @@ -351,14 +376,16 @@ int run_gemm_example_with_layouts(int argc, std::unique_ptr bq_dev_buf_ptr = nullptr; if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) { bq_dev_buf_ptr = std::make_unique(bq_tensor_ptr->get_element_space_size_in_bytes()); } if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) { if constexpr(GemmConfig::PreshuffleQuant) { @@ -383,22 +410,33 @@ int run_gemm_example_with_layouts(int argc, { a_m_k_dev_buf.ToDevice(a_m_k.data()); } + + ck_tile::HostTensor b_k_n_dev = b_k_n; if constexpr(std::is_same_v) { - // Permute vector pk_i4x4 data for device implementation - ck_tile::HostTensor b_k_n_dev = b_k_n; + + if constexpr(GemmConfig::PreshuffleB) + { + b_k_n_dev = shuffle_b(b_k_n); + } ck_tile::permute_vectors_i4x4_b(b_k_n_dev); b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); } else { - b_k_n_dev_buf.ToDevice(b_k_n.data()); + if constexpr(GemmConfig::PreshuffleB) + { + b_k_n_dev = shuffle_b(b_k_n); + } + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); } + c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) { bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data()); } @@ -412,15 +450,9 @@ int run_gemm_example_with_layouts(int argc, CLayout, QuantGroupSize, QuantMode>(a_m_k_dev_buf, - (QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) - ? aq_dev_buf_ptr.get() - : nullptr, + aq_dev_buf_ptr.get(), b_k_n_dev_buf, - (QuantMode == ck_tile::QuantType::BQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) - ? bq_dev_buf_ptr.get() - : nullptr, + bq_dev_buf_ptr.get(), c_m_n_dev_buf, M, N, @@ -467,7 +499,7 @@ int run_gemm_example_with_layouts(int argc, QuantGroupSize, false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); } - else + else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) { ck_tile::reference_gemm_rowcol_quant( a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref); } + else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) + { + ck_tile::reference_gemm_tensor_quant( + a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref); + } const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); @@ -494,7 +536,7 @@ int run_gemm_example_with_layouts(int argc, << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; } - std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) { diff --git a/example/ck_tile/40_streamk_gemm/CMakeLists.txt b/example/ck_tile/40_streamk_gemm/CMakeLists.txt index 8c200e9bd2..3b285a54b5 100644 --- a/example/ck_tile/40_streamk_gemm/CMakeLists.txt +++ b/example/ck_tile/40_streamk_gemm/CMakeLists.txt @@ -1,12 +1,10 @@ -add_executable(tile_example_gemm_streamk EXCLUDE_FROM_ALL streamk_gemm.cpp) -set(EXAMPLE_GEMM_COMPILE_OPTIONS) -set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) -if(CK_USE_OCP_FP8) - list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(tile_example_streamk_gemm_basic EXCLUDE_FROM_ALL streamk_gemm_basic.cpp) + set(EXAMPLE_GEMM_COMPILE_OPTIONS) + if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + endif() + target_compile_options(tile_example_streamk_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping ck_tile streamk gemm tests for current target") endif() -list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef) -list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker) -list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps) -list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0") -target_compile_options(tile_example_gemm_streamk PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/40_streamk_gemm/README.md b/example/ck_tile/40_streamk_gemm/README.md new file mode 100644 index 0000000000..d2ff7eabc0 --- /dev/null +++ b/example/ck_tile/40_streamk_gemm/README.md @@ -0,0 +1,37 @@ +# Stream-K GEMM + +This folder contains examples of Stream-K GEMMs using the ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx942) or leave it blank +../script/cmake-ck-dev.sh ../ +# Compile the Stream-K kernels +make tile_example_streamk_gemm_basic -j +``` +This will result in an executable `build/bin/tile_example_streamk_gemm_basic` + +## example +``` +args: + -m m dimension (default:512) + -n n dimension (default:512) + -k k dimension (default:512) + -a_layout tensor A data layout (default: R) + -b_layout tensor B data layout (default: C) + -c_layout tensor C data layout (default: R) + -num_sk_blocks number of Stream-K blocks. -1: chosen by algorithm, or user selected (default:-1) +-reduction_strategy strategy for storing results in C tensor. atomic/reduction (default:atomic) + -stride_a tensor A stride (default:0) + -stride_b tensor B stride (default:0) + -stride_c tensor C stride (default:0) + -v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1) + -prec data type. fp16/bf16 (default:fp16) + -warmup number of iterations before benchmarking the kernel (default:50) + -repeat number of iterations to benchmark the kernel (default:100) + -timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu) + -init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0) + -flush_cache flush the cache before running the kernel (default:true) +``` \ No newline at end of file diff --git a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp index b108e992ff..080308e319 100644 --- a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp +++ b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp @@ -1,62 +1,17 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once - -#include - #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/utility/json_dump.hpp" - -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 -#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 - -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} struct GemmConfigBase { - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; static constexpr bool PermuteA = false; static constexpr bool PermuteB = false; @@ -65,19 +20,16 @@ struct GemmConfigBase static constexpr bool UseStructuredSparsity = false; static constexpr bool Persistent = false; - static constexpr int kBlockPerCu = 1; - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; + static constexpr bool DoubleSmemBuffer = false; }; template struct GemmConfigMemoryInterwave : public GemmConfigBase { - // Memory friendly for Interwave scheduler static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 32; @@ -89,168 +41,59 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; -template -struct GemmConfigMemoryIntrawave : public GemmConfigBase +template +struct StreamKGemmTypeConfig { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 32; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; -}; - -template -struct GemmConfigComputeV3 : public GemmConfigBase -{ - // Compute V3 only support Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; -}; - -template -struct GemmConfigPreshufflePrefill : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - - static constexpr int kBlockPerCu = 2; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = false; - static constexpr bool DoubleSmemBuffer = false; - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; -}; - -template -struct GemmTypeConfig; - -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using BDataType = ck_tile::fp8_t; + using ADataType = ADataType_; + using BDataType = BDataType_; using AccDataType = float; - using CDataType = ck_tile::half_t; + using CDataType = CDataType_; }; template struct DataTypeTraits; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + template <> struct DataTypeTraits { static constexpr const char* name = "fp16"; }; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + template <> struct DataTypeTraits { static constexpr const char* name = "fp8"; }; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + template struct PipelineTypeTraits; -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; - template - using UniversalGemmPipeline = - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - template - using UniversalGemmPipeline = - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; -}; - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "2048", "k dimension") + arg_parser.insert("m", "512", "m dimension") + .insert("n", "512", "n dimension") + .insert("k", "512", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Column by default") .insert("c_layout", "R", "C tensor data layout - Row by default") @@ -260,35 +103,17 @@ auto create_args(int argc, char* argv[]) .insert("reduction_strategy", "atomic", "strategy for storing results in C tensor - atomic/reduction") - .insert( - "occupancy", - "-1", - "maximum number of workgroups per CU - value of -1 queries occupancy from the device") - .insert("num_cu", - "-1", - "number of compute units (CUs) - value of -1 uses number of CUs on the device") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t") - .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("prec", "fp16", "data type. fp16/bf16") + .insert("warmup", "50", "number of iterations before benchmarking the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") - .insert("persistent", "0", "0:non-persistent, 1:persistent") - .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "gemm.json", "json file name to dump results") - .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true") - .insert("rotating_count", "1000", "rotating count, defaults to 1000"); + .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } - -// Type aliases for memory operation integral constants -using MemoryOpSet = - std::integral_constant; -using MemoryOpAtomicAdd = std::integral_constant; diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc index 1f916de2e7..5fdf6b29ef 100644 --- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc +++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc @@ -1,11 +1,35 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once -template -static constexpr inline auto is_row_major(Layout layout_) + +// Estimate the number of WGs contributing to the same macro tile in C +template +int estimate_num_wgs_per_tile(const TilePartitioner& tile_partitioner) { - return ck_tile::bool_constant, - ck_tile::tensor_layout::gemm::RowMajor>>{}; + // In the case of non-atomic reduction or DP only, there will always be 1 WG contributing to a + // macro time in C + int num_wgs_per_tile = 1; + + // Otherwise, for atomics, multiple WGs may be contributing to the same macro tile in C + if(tile_partitioner.sk_num_blocks > 0 && + ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) + { + // Determine the number of iterations per WG for a given macro tile in C + uint32_t k_iters_per_block = tile_partitioner.k_iters_per_big_block - 1; + + // Estimate the number of WGs per macro tile + num_wgs_per_tile = (tile_partitioner.k_iters_per_tile.get() / (k_iters_per_block)) + + ((tile_partitioner.k_iters_per_tile.get() % k_iters_per_block) != 0); + } + + return std::max(num_wgs_per_tile, 1); +} + +template +static constexpr inline auto is_row_major(Layout) +{ + return ck_tile::bool_constant< + std::is_same_v, ck_tile::tensor_layout::gemm::RowMajor>>{}; } template @@ -20,7 +44,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K, ck_tile::integer_divide_ceil(K, kbatch)); const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); - // Calculate error due to split_k accumulation + // Calculate error due to multiple WGs working in the same C macro tile const auto rtol_split_k = ck_tile::get_relative_threshold(kbatch); const auto atol_split_k = ck_tile::get_absolute_threshold( @@ -30,123 +54,6 @@ auto calculate_rtol_atol(const ck_tile::index_t K, } template -void permute_tensor_b(Tensor& tensor) -{ - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence, - GemmConfig::PermuteA, - GemmConfig::PermuteB>; - - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< - UniversalGemmProblem>; - - const ck_tile::index_t K = tensor.get_length(0); - const ck_tile::index_t N = tensor.get_length(1); - const ck_tile::index_t K1 = GemmPipeline::GetSmemPackB(); - const ck_tile::index_t K0 = K / K1; - - Tensor tensor_copy = tensor; - - // int K0, N, K1 - for(int j = 0; j < K0; j++) - { - for(int i = 0; i < N; i++) - { - for(int jj = 0; jj < K1; jj++) - { - tensor(j * N * K1 + i * K1 + jj) = tensor_copy(i * K + (j * K1 + jj)); - } - } - } -} - -template -void permute_vectors_i4x4_b(Tensor& tensor) -{ - const ck_tile::index_t K = tensor.get_length(0); - const ck_tile::index_t N = tensor.get_length(1); - // vector pk_i4x4 permute - for(int i = 0; i < N; i++) - { - for(int j = 0; j < K; j += 8) - { - int8_t input[8]; - - for(int k = 0; k < 4; k++) - { - int8_t i4x2 = tensor(j + k * 2, i).data; - input[k * 2 + 0] = (i4x2 >> 4) & 0xf; - input[k * 2 + 1] = (i4x2 >> 0) & 0xf; - } - - // permute 01234567->20643175 - { - int8_t hi = input[2]; - int8_t lo = input[0]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 0, i) = i4x2; - } - - { - int8_t hi = input[6]; - int8_t lo = input[4]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 2, i) = i4x2; - } - - { - int8_t hi = input[3]; - int8_t lo = input[1]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 4, i) = i4x2; - } - - { - int8_t hi = input[7]; - int8_t lo = input[5]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 6, i) = i4x2; - } - } - } -} - -template -float gemm(const ck_tile::StreamKHostArgs& args, - const ck_tile::stream_config& s, - int num_cu, - int occupancy); + typename CDEElementWise = ck_tile::element_wise::PassThrough, + ck_tile::StreamKReductionStrategy ReductionStrategy> +std::tuple gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s); template -float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C, - int n_warmup, - int n_repeat, - ck_tile::StreamKReductionStrategy reduction_strategy, - uint32_t num_sk_blocks, - int num_cu, - int occupancy) +std::tuple invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + int n_warmup, + int n_repeat, + bool flush_cache, + ck_tile::StreamKReductionStrategy reduction_strategy, + uint32_t num_sk_blocks) { ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), b_k_n_dev_buf.GetDeviceBuffer(), @@ -201,63 +105,42 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, reduction_strategy, num_sk_blocks}; - float ave_time; + std::tuple ave_time_and_batch; + if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) { - ave_time = - gemm(args, - ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true}, - num_cu, - occupancy); + ave_time_and_batch = gemm( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache}); } - else + else /*Reduction*/ { - ave_time = - gemm(args, - ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true}, - num_cu, - occupancy); + ave_time_and_batch = gemm( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache}); } - return ave_time; -} - -template -auto shuffle_b(const ck_tile::HostTensor& t) -{ - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + return ave_time_and_batch; } template @@ -279,19 +162,46 @@ bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, return pass; } +ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy) +{ + if(strategy == "atomic") + { + return ck_tile::StreamKReductionStrategy::Atomic; + } + else if(strategy == "reduction") + { + return ck_tile::StreamKReductionStrategy::Reduction; + } + else + { + throw std::runtime_error("Unsupported Stream-K reduction strategy !!!"); + } +} + template -int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, +int run_gemm_example_with_layouts(int argc, + char* argv[], const ALayout a_layout = ALayout{}, const BLayout b_layout = BLayout{}, [[maybe_unused]] const CLayout c_layout = CLayout{}) { - using AccDataType = typename GemmTypeConfig::AccDataType; + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + static_assert(!GemmConfig::Preshuffle, "Not implemented"); + static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented"); + static_assert(!GemmConfig::PermuteA, "Not implemented"); + static_assert(!GemmConfig::PermuteB, "Not implemented"); + + using ADataType = typename TypeConfig::ADataType; + using BDataType = typename TypeConfig::BDataType; + using AccDataType = typename TypeConfig::AccDataType; + using CDataType = typename TypeConfig::CDataType; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); @@ -301,17 +211,14 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - // ck_tile::index_t kbatch = 1; int n_warmup = arg_parser.get_int("warmup"); int n_repeat = arg_parser.get_int("repeat"); ck_tile::index_t init_method = arg_parser.get_int("init"); + bool flush_cache = arg_parser.get_bool("flush_cache"); ck_tile::StreamKReductionStrategy reduction_strategy = - ck_tile::StreamKReductionStrategy::Atomic; + get_reduction_strategy_value(arg_parser.get_str("reduction_strategy")); uint32_t num_sk_blocks = static_cast(arg_parser.get_int("num_sk_blocks")); - // uint32_t num_sk_blocks = 0; - int num_cu = arg_parser.get_int("num_cu"); - int occupancy = arg_parser.get_int("occupancy"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); @@ -349,36 +256,34 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); - a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - float ave_time = invoke_gemm, - AccDataType, - CDataType, - ALayout, - BLayout, - ck_tile::tuple<>, - CLayout>(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - n_warmup, - n_repeat, - reduction_strategy, - num_sk_blocks, - num_cu, - occupancy); + auto [ave_time, num_wgs_per_tile] = invoke_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + n_warmup, + n_repeat, + flush_cache, + reduction_strategy, + num_sk_blocks); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); @@ -390,44 +295,34 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C - << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name + << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name << " B_Type=" << DataTypeTraits::name << " C_Type=" << DataTypeTraits::name - << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") - << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " " + << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; bool pass = true; - // memory on host to store gpu reference result + // Memory on host to store gpu reference result ck_tile::HostTensor c_m_n_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_ref.SetZero(); - if(arg_parser.get_int("v") == 1) + if(arg_parser.get_int("v") == 1) // Validate on the CPU { ck_tile::reference_gemm( a_m_k, b_k_n, c_m_n_ref); const float max_accumulated_value = *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); const auto rtol_atol = calculate_rtol_atol( - K, 1, max_accumulated_value); + K, num_wgs_per_tile, max_accumulated_value); pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU"); } - else if(arg_parser.get_int("v") == 2) + else if(arg_parser.get_int("v") == 2) // Validate on the GPU { - if constexpr(std::is_same_v) - { - // Restore input for B for gpu reference - b_k_n_dev_buf.ToDevice(b_k_n.data()); - } - if constexpr(GemmConfig::Preshuffle) - { - b_k_n_dev_buf.ToDevice(b_k_n.data()); - } - - // memory on device to store gpu reference result + // Memory on device to store gpu reference result ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes()); c_m_n_gpu_buf_ref.SetZero(); @@ -448,7 +343,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, const float max_accumulated_value = *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); const auto rtol_atol = calculate_rtol_atol( - K, 1, max_accumulated_value); + K, num_wgs_per_tile, max_accumulated_value); pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU"); } diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm.cpp index 92de421b38..5a479e7c9c 100644 --- a/example/ck_tile/40_streamk_gemm/streamk_gemm.cpp +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm.cpp @@ -204,6 +204,13 @@ int run_gemm_example(int argc, char* argv[]) ck_tile::fp8_t, ck_tile::half_t>(a_layout, b_layout, arg_parser); } + else if(data_type == "bf8") + { + return run_gemm_example_prec_type, + ck_tile::bf8_t, + ck_tile::bf8_t, + ck_tile::half_t>(a_layout, b_layout, arg_parser); + } else { throw std::runtime_error("Unsupported data type for this operation !!!"); diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp new file mode 100644 index 0000000000..4f2e7a06d8 --- /dev/null +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -0,0 +1,205 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" + +template +std::tuple gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s) + +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = ck_tile::StreamKTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + const auto Run = [&](const auto memory_operation) -> std::tuple { + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::StreamKKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + dim3 grids = Kernel::GridSize(kargs.tile_partitioner); + dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + // Function to clear the output C tensor results after each repetition of the kernel + auto clear_gemm_output = [&]() { + if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + + std::function preprocess = clear_gemm_output; + + float ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + int num_wgs_per_tile = estimate_num_wgs_per_tile(kargs.tile_partitioner); + + return std::tuple{ave_time, num_wgs_per_tile}; + }; + + if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy) + { + return Run(ck_tile::integral_constant{}); + } + else // We are using ck_tile::StreamKReductionStrategy::Reduction + { + return Run(ck_tile::integral_constant{}); + } +} + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported layouts."); + } + + return 0; +} + +template