From eb7f6177136173c8a6af539bffd915fddff293c4 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 4 Dec 2025 12:18:25 +0800 Subject: [PATCH 01/65] fp8 fmha async pipeline (#3339) * replace qr with async pipeline * Add fp8fp32 to DTYPE_BITS * Add kAlignmentRandVal to avoid compile fail * format --------- Co-authored-by: Thomas Ning --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 23 +++++++++++++------ ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 2 ++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 360d6a7c78..17d4f6e1d7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -29,7 +29,15 @@ from codegen.cpp_symbol_map import ( from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8": 8, + "fp8bf16": 8, + "fp8fp32": 8, + "bf8": 8, +} K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} @@ -678,6 +686,7 @@ class KernelComponentFactoryGfx9: return { ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip elif dtype in ["fp8fp32"]: @@ -742,8 +751,8 @@ class KernelComponentFactoryGfx9: get_mask_map(mask_impl).keys(), ["no"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, 
bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO None @@ -958,7 +967,7 @@ def get_fwd_blobs( cond &= mode == "batch" cond &= pipeline.F_vlayout == "row" if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 256 + cond &= hdim == 128 or hdim == 192 if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -967,7 +976,7 @@ def get_fwd_blobs( cond &= mode == "group" cond &= pipeline.F_vlayout == "row" if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 256 + cond &= hdim == 128 or hdim == 192 if not cond: continue # aiter::mha_fwd C++ api integration @@ -975,13 +984,13 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16", "fp8bf16"] cond &= pipeline.F_vlayout == "row" if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 256 + cond &= hdim == 128 or hdim == 192 if not cond: continue elif receipt == 888: cond = dtype in ["fp8bf16", "fp8fp32"] cond &= pipeline.F_vlayout == "row" - cond &= hdim == 128 or hdim == 256 + cond &= hdim == 128 or hdim == 192 if not cond: continue diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 27776453f6..2102fe768f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -87,6 +87,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); + static constexpr index_t kAlignmentRandVal = + kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal(); #if CK_TILE_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / log2e_v; From ffc3120f63135cc697e46031523e44c5cd5d43fa Mon Sep 17 00:00:00 2001 From: kensclin Date: Thu, 4 Dec 2025 14:07:23 +0800 Subject: [PATCH 02/65] Ck tile/gemm blockscale opt (#3227) * GEMM block scale optimization kernel * GEMM block scale optimization kernel * Fix: Apply clang-format for style consistency * Fix: Apply clang-format for style consistency --------- Co-authored-by: Thomas Ning --- .../38_block_scale_gemm/gemm_utils.hpp | 1 + .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 98 ++++++++++++++----- 2 files changed, 75 insertions(+), 24 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 116661c157..2b2333b04c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -211,6 +211,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr int kBlockPerCu = 2; }; template diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 59a5b0df4e..d83338fbb2 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -69,7 +69,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV using Base::m_preload; - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr 
index_t VectorLoadSize = Problem::VectorLoadSize; static constexpr index_t KPerBlockBQ = integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK); static constexpr index_t QScalesPerBlockRow = @@ -95,6 +96,56 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV // clang-format on } + template + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t Aload_inst = + (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize; + constexpr index_t Bload_inst = + (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize; + constexpr index_t BQload_inst = ck_tile::integer_divide_ceil( + ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType), + QuantGroupSize::kK * QuantGroupSize::kK), + VectorLoadSize); + constexpr index_t kLdsVec = 8; + constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; + constexpr index_t ds_read_inst = kMPerBlock / kLdsVec; + constexpr index_t ds_write_inst = Aload_inst; + constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + constexpr index_t buffer_load_rep = + min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma + + static_for<0, nloop, 1>{}([&](auto j_inst) { + ignore = j_inst; + static_for<0, mfma_inst, 1>{}([&](auto i_inst) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA + + if constexpr(ds_rep > 0 && i_inst % ds_rep == 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_READ, 1, 0); // DS read + } + if constexpr(ds_rep > 0 && i_inst % ds_rep == 1) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write + } + + if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0) + { + if constexpr(ds_write_inst > 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM 
read + } + } + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + static constexpr bool PreshuffleB = Problem::PreshuffleB; static constexpr auto TailNum = Problem::TailNum; @@ -130,6 +181,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV static_assert(!is_b_row_major, "B must be col major (row major not supported yet)"); const index_t iMWarp = get_warp_id() / NWarp; + // Double-Buffering (loop_count=2) for full load/compute overlap. + const index_t loop_count = 2; __builtin_amdgcn_sched_barrier(0); @@ -313,9 +366,26 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV __builtin_amdgcn_sched_barrier(0); // MAIN LOOP - index_t iCounter = (num_loop - 1) / 2; + index_t iCounter = (num_loop - 1) / loop_count; + while(iCounter > 0) { + __builtin_amdgcn_sched_barrier(0); + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + bq_block_tile, + a_warp_windows_ping); // prefetch B(2i+1) static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -342,29 +412,12 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); } - // Prefill A(2i+1) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); - - // Prefetch A(2i+2) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // GEMM 2i - 
block_weight_preshuffle(c_block_tile, - a_warp_tensor, - b_warp_tensor_ping, - bq_block_tile, - a_warp_windows_ping); - static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); }); - Base::HotLoopScheduler(); // Next K @@ -416,9 +469,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); }); - Base::HotLoopScheduler(); - iCounter--; + HotLoopScheduler(); } // tail @@ -456,15 +508,13 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV load_tile(a_warp_windows_pong(number{})(number{})); }); - Base::Last2ndHotLoopScheduler(); - // GEMM loopK block_weight_preshuffle(c_block_tile, a_warp_tensor, b_warp_tensor_pong, bq_block_tile_2, a_warp_windows_pong); - Base::LastHotLoopScheduler(); + HotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { From 583fafc803a0ec9d0edc902fc6b9ecfdc42fb09b Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Wed, 3 Dec 2025 22:46:22 -0800 Subject: [PATCH 03/65] [CK_TILE] Fix for Moving DataTypeTraits into a Common File (#3335) This PR fixes a mismatch caused when PR #3146 was merged out of sync with develop, which made its intended changes ineffective. This PR reapplies those changes to move DataTypeTraits into a common file to mitigate code duplication. 
Co-authored-by: Thomas Ning --- .../gemm_multi_d_benchmark_single.cpp | 12 ++-- .../gemm_streamk_benchmark_single.cpp | 8 +-- .../ops/gemm_streamk/gemm_streamk_common.hpp | 59 +------------------ 3 files changed, 11 insertions(+), 68 deletions(-) diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp index 41d2f736e1..25ac342f3e 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp @@ -80,12 +80,12 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser) { // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = DataTypeTraits::name; - std::string dtype_b = DataTypeTraits::name; - std::string dtype_acc = DataTypeTraits::name; - std::string dtype_c = DataTypeTraits::name; - std::string dtype_d0 = DataTypeTraits::name; - std::string dtype_d1 = DataTypeTraits::name; + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; + std::string dtype_d0 = ck_tile::DataTypeTraits::name; + std::string dtype_d1 = ck_tile::DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp index 13cadcd55a..5e88dc486a 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp @@ -83,10 +83,10 @@ void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser) { // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines 
ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = DataTypeTraits::name; - std::string dtype_b = DataTypeTraits::name; - std::string dtype_acc = DataTypeTraits::name; - std::string dtype_c = DataTypeTraits::name; + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp index 179aeb7307..15a3c91964 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp @@ -6,67 +6,10 @@ #include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" -// DataTypeTraits for all supported types -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - // Helper function to determine if a layout is row-major template constexpr auto 
is_row_major(Layout) From 9cb1f421bce29cb70bf7905858d2f8823f586621 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Thu, 4 Dec 2025 12:58:31 +0200 Subject: [PATCH 04/65] [CK_BUILDER] Refactor convolution signature to provide data type/layout/elementwise op per tensor (#3331) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Separate layouts into separate entities for input, weight, and output tensors. * Add test for handling bias tensor layouts. * Use instance string in builder tests. * Add handling of output bias data types and layouts. * Generalize handling of the elementwise ops. * Test fix. * Create builder for layouts. * Layout builder improvements. * Improve layout builder. * Simplify bias layout handling. * Code clean-up. * Move layout utils into separate file. * Remove hard-coded layout combinations. * Small code clean-up. * Move data type utils into a separate file. * Add data types, layouts, and elementwise ops per conv tensor. * Builder bug fixes after refactoring. * Working baseline. * Make signature definition look nice in the test code. * Move TensorConfig into test implementations. * Fix all fwd conv builder tests. * Fix conv traits and descriptors tests. * More factory assets under a separate directory. * Fix building conv traits. * Fix clang-format. * Add Readme doc to describe the design. * Add link to main Readme. Fix links in the builder design doc. * Clean-up data type/layout/elementwise op conversions. * Switch from dimension and tensor type specific layouts to a flat list of tensor layouts. * Fix clang-formatting. * Fix clang-format for test code. * Simplify fwd conv signature definitions in the test code. * Remove accidental edits. * Fix comment string. * Fix instance factory after rebase. * Fix tests after rebase. * Unify layout handling. * Add more conv layout unit tests. * Clang-format. * Fix merge conflicts. 
* Improve elementwise op handling. --------- Co-authored-by: Ville Pietilä <> --- experimental/builder/README.md | 4 + .../builder/include/ck_tile/builder/README.md | 244 +++++++++++ .../builder/conv_signature_concepts.hpp | 149 ++++++- .../ck_tile/builder/conv_signature_utils.hpp | 47 --- .../builder/factory/conv_fwd_dl_factory.hpp | 9 +- .../factory/conv_fwd_large_tensor_factory.hpp | 9 +- .../builder/factory/conv_fwd_v3_factory.hpp | 9 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 9 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 9 +- .../factory/helpers/conv_elementwise_op.hpp | 80 +++- .../factory/helpers/conv_tensor_layout.hpp | 267 ++++++++---- .../factory/helpers/conv_tensor_type.hpp | 204 ++++++--- .../builder/reflect/conv_description.hpp | 13 +- .../ck_tile/builder/reflect/conv_traits.hpp | 71 ++-- .../builder/include/ck_tile/builder/types.hpp | 220 +++++----- experimental/builder/test/CMakeLists.txt | 1 + .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 23 +- .../test/conv/test_ckb_conv_fwd_1d_fp16.cpp | 23 +- .../test/conv/test_ckb_conv_fwd_1d_i8.cpp | 20 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 40 +- ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 46 ++ .../conv/test_ckb_conv_fwd_2d_dl_fp16.cpp | 42 +- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 22 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 22 +- .../test/conv/test_ckb_conv_fwd_2d_fp8.cpp | 21 +- ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 40 +- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 23 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 23 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 23 +- .../builder/test/conv/test_conv_traits.cpp | 15 +- .../test/impl/conv_signature_types.hpp | 40 +- .../builder/test/test_conv_description.cpp | 79 +++- .../builder/test/unit_conv_elementwise_op.cpp | 38 +- .../builder/test/unit_conv_tensor_layout.cpp | 397 +++++++++++++++++- .../builder/test/unit_conv_tensor_type.cpp | 61 +-- .../test/utils/ckb_conv_test_configs.hpp | 3 + 
.../test/utils/ckb_conv_test_utils.hpp | 2 +- 37 files changed, 1731 insertions(+), 617 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/README.md delete mode 100644 experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp diff --git a/experimental/builder/README.md b/experimental/builder/README.md index aa7c7d969d..18e9e58739 100644 --- a/experimental/builder/README.md +++ b/experimental/builder/README.md @@ -10,6 +10,10 @@ The builder provides a high-level, semantically-clear interface for constructing This project is a prototype for a more general builder pattern for all of composable_kernel (CK) and CKTile, but is currently limited to formalizing the interface between MIOpen and CK. +## Design descriptions + +- [CK Builder design description](include/ck_tile/builder/README.md) + ## Directory Structure - `include/ck_tile/builder/` diff --git a/experimental/builder/include/ck_tile/builder/README.md b/experimental/builder/include/ck_tile/builder/README.md new file mode 100644 index 0000000000..a0522a50d6 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/README.md @@ -0,0 +1,244 @@ +# Composable Kernel Builder Design Documentation + +This directory contains the builder framework for Composable Kernel, which provides a compile-time, type-safe interface for constructing convolution operations with various configurations. + +## Table of Contents + +- [Convolution Signature Design](#convolution-signature-design) + - [Overview](#overview) + - [Architecture](#architecture) + - [Core Components](#core-components) + - [Concepts and Validation](#concepts-and-validation) +--- + +## Convolution Signature Design + +### Overview + +The convolution signature system provides a **compile-time description** of grouped convolution operations. 
A signature is a collection of properties that fully characterize a convolution kernel's mathematical and operational behavior, enabling: + +- **Compile-time validation**: Ensures type safety and correctness before kernel instantiation +- **Kernel selection**: Matches user requirements to optimized implementations +- **Specialization**: Enables optimized code paths for specific configurations +- **Composability**: Supports building complex operations from simpler components + +The signature leverages modern C++20 features, particularly **concepts**, to provide expressive, self-documenting interfaces with compile-time guarantees. + +### Architecture + +The signature system is organized into a hierarchical structure: + +``` +┌─────────────────────────────────────────────────────────┐ +│ ConvSignature │ +├─────────────────────────────────────────────────────────┤ +│ Properties: │ +│ • spatial_dim: int (1D, 2D, or 3D) │ +│ • direction: ConvDirection (Fwd/BwdData/BwdWeight) │ +│ • data_type: DataType (default data type) │ +│ • accumulation_data_type: DataType │ +│ • input: ConvTensor ──┐ │ +│ • weight: ConvTensor ──│ │ +│ • output: ConvTensor ──│ │ +└──────────────────────────────────┼──────────────────────┘ + │ + ▼ + ┌─────────────────────────────────────────┐ + │ ConvTensor │ + ├─────────────────────────────────────────┤ + │ ╔═════════════════════════════════════╗ │ + │ ║ TensorConfig (required) ║ │ + │ ╠═════════════════════════════════════╣ │ + │ ║ • layout: ConvLayout ║ │ + │ ║ • data_type: DataType (optional) ║ │ + │ ║ • compute_type: DataType (optional)║ │ + │ ╚═════════════════════════════════════╝ │ + │ │ + │ ┌─────────────────────────────────────┐ │ + │ │ TensorOperation (optional) │ │ + │ ├─────────────────────────────────────┤ │ + │ │ • elementwise_operation │ │ + │ │ • auxiliary_operand_configs[] │ │ + │ │ (each is also ConvTensor) ◄───────┼─┐ + │ └─────────────────────────────────────┘ │ │ + └─────────────────────────────────────────┘ │ + │ + Recursive 
───────────────┘ +``` +Key Design Points: + - ConvSignature contains three ConvTensor instances (input, weight, output) + - All tensors share the same ConvTensor structure + - Each ConvTensor has: + - TensorConfig (required): Defines layout as well as optional data and compute type overrides + - TensorOperation (optional): Defines fused elementwise operations + - Auxiliary operands (e.g., bias) in TensorOperation also use the ConvTensor type + +### Core Components + +#### 1. Signature Level + +The top-level signature contains global properties that apply to the entire convolution operation: + +```cpp +template +concept ConvSignatureDescriptor = requires(T t) { + { t.spatial_dim } -> std::convertible_to; // 1, 2, or 3 + { t.data_type } -> std::convertible_to; // Default data type + { t.input } -> ConvTensorDescriptor; + { t.weight } -> ConvTensorDescriptor; + { t.output } -> ConvTensorDescriptor; + requires ConvolutionDirectionWellDefinedIfProvided; // Optional direction +}; +``` + +**Properties:** +- **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D) +- **`direction`**: Operation type (optional, defaults to FORWARD) + - `FORWARD`: Standard forward convolution + - `BACKWARD_DATA`: Gradient computation w.r.t. input + - `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights +- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8) +- **`accumulation_data_type`**: Type used for internal accumulation + +#### 2. Tensor Level + +Each tensor (input, weight, output) has its own descriptor: + +```cpp +template +concept ConvTensorDescriptor = requires(T t) { + { t.config } -> TensorConfigDescriptor; + requires ElementwiseOpWellDefinedIfProvided; +}; +``` + +A tensor descriptor encapsulates: +- **Configuration**: Layout and data type information +- **Operation** (optional): Fused elementwise operations on this tensor + +#### 3. 
Tensor Configuration + +Describes the memory layout and data types: + +```cpp +template +concept TensorConfigDescriptor = requires(T t) { + { t.layout } -> std::convertible_to; + { t.data_type } -> std::convertible_to; // Optional override +}; +``` + +**Layout Types** (dimension-specific): +- **1D Convolution**: + - Input: `GNCW`, `GNWC`, `NWGC`, `NGCW`, `G_NW_C_strided` + - Weight: `GKXC`, `GKCX`, `KXGC`, `G_K_X_C_strided` + - Output: `GNKW`, `GNWK`, `NWGK`, `NGKW`, `G_NW_K_strided` + +- **2D Convolution**: + - Input: `GNCHW`, `GNHWC`, `NHWGC`, `NGCHW`, `G_NHW_C_strided` + - Weight: `GKYXC`, `GKCYX`, `KYXGC`, `G_K_YX_C_strided` + - Output: `GNKHW`, `GNHWK`, `NHWGK`, `NGKHW`, `G_NHW_K_strided` + +- **3D Convolution**: + - Input: `GNCDHW`, `GNDHWC`, `NDHWGC`, `NGCDHW`, `G_NDHW_C_strided` + - Weight: `GKZYXC`, `GKCZYX`, `KZYXGC`, `G_K_ZYX_C_strided` + - Output: `GNKDHW`, `GNDHWK`, `NDHWGK`, `NGKDHW`, `G_NDHW_K_strided` + +Where: +- `G` = Groups +- `N` = Batch size +- `C` = Input channels +- `K` = Output channels (filters) +- `W`, `H`, `D` = Width, Height, Depth (spatial dimensions) +- `X`, `Y`, `Z` = Filter dimensions + +#### 4. Tensor Operations + +Describes fused elementwise operations applied to a tensor: + +```cpp +template +concept TensorOperatorDescriptor = requires(T t) { + { t.elementwise_operation } -> std::convertible_to; + requires AuxiliaryOperandConfigsWellDefinedIfProvided; +}; +``` + +**Supported Operations:** +- `PASS_THROUGH`: No operation (identity) +- `SCALE`: Multiply by a scalar +- `CLAMP`: Clamp values to a range +- `BIAS_BNORM_CLAMP`: Bias addition + batch normalization + clamp +- `SCALEADD_SCALEADD_RELU`: Fused scale-add operations + ReLU activation + +**Auxiliary Operands:** +Some operations require additional tensor inputs (e.g., bias tensors, scaling factors). These are specified through `auxiliary_operand_configs`, which is an array of `TensorConfigDescriptor` objects describing the layout and data type of each auxiliary input. 
+ +### Concepts and Validation + +The signature system uses C++20 concepts for compile-time validation at multiple levels: + +#### Constraint Concepts + +```cpp +// Spatial dimension must be 1, 2, or 3 +template +concept ConvSpatialDim = std::is_integral_v && (N == 1 || N == 2 || N == 3); + +// Valid data types for convolution +template +concept ValidConvDataType = + (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || + (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); +``` + +#### Validation Concept + +```cpp +// Validates a complete signature +template +concept ValidConvSignature = requires { + requires ConvSpatialDim; + requires ValidConvDataType; +}; +``` + +#### Tensor Descriptors + +The layout/data type/elementwise operation are described per tensor. This multi-level hierarchy allows: +- **Flexibility**: Each tensor can have independent layout and data type +- **Reusability**: Common configurations can be shared across different signatures +- **Extensibility**: New properties can be added to specific levels without affecting others +- **Clarity**: Separates concerns (global properties vs. tensor-specific properties) + +#### Optional Signature Fields + +Several fields in the signature are optional: +- **`direction`**: Defaults to `FORWARD` if not specified, reducing boilerplate for the common case +- **Tensor `data_type`**: Falls back to signature's default, allowing mixed-precision with minimal specification +- **Tensor `operation`**: Defaults to `PASS_THROUGH`, supporting both fused and non-fused operations with the same interface + +This design follows the principle of "make the common case simple, the complex case possible." 
+ +#### Union-Based Layout Representation + +The `ConvLayout` type uses unions to support dimension-agnostic code: + +```cpp +struct ConvLayout { + union { + ConvInputLayout _input_layout; + ConvWeightLayout _weight_layout; + ConvOutputLayout _output_layout; + ConvAuxiliaryTensorLayout _aux_tensor_layout; + }; + // ... constructors for each type +}; +``` + +This allows: +- Single type to represent all layout variants +- Type-safe construction through overloaded constructors +- Compile-time enforcement of valid combinations through concepts + +--- diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 05575590c4..8dc92c6bef 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -28,24 +28,104 @@ namespace ck_tile::builder { template concept ConvSpatialDim = std::is_integral_v && (N == 1 || N == 2 || N == 3); -// Constraints for forward convolution layouts. -template -concept ValidConvLayoutForSpatialDim = - (SpatialDim == 1 && std::same_as) || - (SpatialDim == 2 && std::same_as) || - (SpatialDim == 3 && std::same_as); - // Constrains convolution data types to common floating-point types. 
template -concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || - (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); +concept ValidConvDataType = + (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || + (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); + +template +concept BiasTensorLayout = + (L == TensorLayout::GC) || (L == TensorLayout::G_C_strided) || (L == TensorLayout::G_K_strided); + +template +concept ConvInputLayout1D = + (L == TensorLayout::GNCW) || (L == TensorLayout::GNWC) || (L == TensorLayout::NWGC) || + (L == TensorLayout::NGCW) || (L == TensorLayout::G_NW_C_strided); + +template +concept ConvInputLayout2D = + (L == TensorLayout::GNCHW) || (L == TensorLayout::GNHWC) || (L == TensorLayout::NHWGC) || + (L == TensorLayout::NGCHW) || (L == TensorLayout::G_NHW_C_strided); + +template +concept ConvInputLayout3D = + (L == TensorLayout::GNCDHW) || (L == TensorLayout::GNDHWC) || (L == TensorLayout::NDHWGC) || + (L == TensorLayout::NGCDHW) || (L == TensorLayout::G_NDHW_C_strided); + +template +concept ConvWeightLayout1D = (L == TensorLayout::GKXC) || (L == TensorLayout::GKCX) || + (L == TensorLayout::KXGC) || (L == TensorLayout::G_K_X_C_strided); + +template +concept ConvWeightLayout2D = (L == TensorLayout::GKYXC) || (L == TensorLayout::GKCYX) || + (L == TensorLayout::KYXGC) || (L == TensorLayout::G_K_YX_C_strided); + +template +concept ConvWeightLayout3D = (L == TensorLayout::GKZYXC) || (L == TensorLayout::GKCZYX) || + (L == TensorLayout::KZYXGC) || (L == TensorLayout::G_K_ZYX_C_strided); + +template +concept ConvOutputLayout1D = + (L == TensorLayout::GNKW) || (L == TensorLayout::GNWK) || (L == TensorLayout::NWGK) || + (L == TensorLayout::NGKW) || (L == TensorLayout::G_NW_K_strided); + +template +concept ConvOutputLayout2D = + (L == TensorLayout::GNKHW) || (L == TensorLayout::GNHWK) || (L == TensorLayout::NHWGK) || + (L == TensorLayout::NGKHW) || (L == 
TensorLayout::G_NHW_K_strided); + +template +concept ConvOutputLayout3D = + (L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) || + (L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided); template -concept ConvLayout = std::same_as, GroupConvLayout>; +concept TensorConfigDescriptor = requires(T t) { + { t.layout } -> std::convertible_to; + // Only require that data type is defined. It might be set to undefined value, in which case the + // signature's data type is used. + { t.data_type } -> std::convertible_to; +}; template -concept HasElementwiseOp = requires(T t) { - { t.elementwise_operation }; +concept HasAuxiliaryOperandConfigs = requires(T t) { + { t.auxiliary_operand_configs }; +}; + +namespace detail { +template +struct IsArrayOfTensorConfigDescriptors : std::false_type +{ +}; + +template + requires TensorConfigDescriptor +struct IsArrayOfTensorConfigDescriptors> : std::true_type +{ +}; +} // namespace detail + +template +concept ConvertibleToArrayOfTensorConfigs = + detail::IsArrayOfTensorConfigDescriptors>::value; + +template +concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) { + requires !HasAuxiliaryOperandConfigs || requires { + { t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs; + }; +}; + +template +concept TensorOperatorDescriptor = requires(T t) { + { t.elementwise_operation } -> std::convertible_to; + requires AuxiliaryOperandConfigsWellDefinedIfProvided; +}; + +template +concept HasTensorOp = requires(T t) { + { t.operation }; }; template @@ -56,11 +136,8 @@ concept HasConvolutionDirection = requires(T t) { // Note: it is not required to provide an ElementwiseOp, but if one is provided, check if well // defined template -concept ElementwiseOpWellDefinedIfProvided = requires(T t) { - requires !HasElementwiseOp || requires { - { t.elementwise_operation } -> std::convertible_to; - }; -}; +concept ElementwiseOpWellDefinedIfProvided = + !HasTensorOp || requires(T 
t) { requires TensorOperatorDescriptor; }; // Note: it is not required to provide a convolution, but if one is provided, check if well defined template @@ -70,13 +147,27 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) { }; }; +// Concept for the convolution tensor +template +concept ConvTensorDescriptor = requires(T t) { + { t.config } -> TensorConfigDescriptor; + requires ElementwiseOpWellDefinedIfProvided; +}; + +template +concept HasElementwiseOpWithAuxiliaryOperands = requires(T t) { + requires HasTensorOp; + requires HasAuxiliaryOperandConfigs; +}; + // Concept for a type that defines a convolution's operational signature. template concept ConvSignatureDescriptor = requires(T t) { { t.spatial_dim } -> std::convertible_to; - { t.layout } -> ConvLayout; { t.data_type } -> std::convertible_to; - requires ElementwiseOpWellDefinedIfProvided; + { t.input } -> ConvTensorDescriptor; + { t.weight } -> ConvTensorDescriptor; + { t.output } -> ConvTensorDescriptor; requires ConvolutionDirectionWellDefinedIfProvided; }; @@ -84,7 +175,7 @@ concept ConvSignatureDescriptor = requires(T t) { template concept ValidConvSignature = requires { requires ConvSpatialDim; - requires ConvDataType; + requires ValidConvDataType; }; // Predicate for forward convolution (default if direction is not included). @@ -100,4 +191,22 @@ concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_ template concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); +// Constraints for forward convolution input layouts. +template +concept ValidConvInputLayoutForSpatialDim = + (SpatialDim == 1 && ConvInputLayout1D) || (SpatialDim == 2 && ConvInputLayout2D) || + (SpatialDim == 3 && ConvInputLayout3D); + +// Constraints for forward convolution output layouts. 
+template +concept ValidConvOutputLayoutForSpatialDim = + (SpatialDim == 1 && ConvOutputLayout1D) || (SpatialDim == 2 && ConvOutputLayout2D) || + (SpatialDim == 3 && ConvOutputLayout3D); + +// Constraints for forward convolution weight layouts. +template +concept ValidConvWeightLayoutForSpatialDim = + (SpatialDim == 1 && ConvWeightLayout1D) || (SpatialDim == 2 && ConvWeightLayout2D) || + (SpatialDim == 3 && ConvWeightLayout3D); + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp deleted file mode 100644 index 65a4b60588..0000000000 --- a/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/builder/types.hpp" - -namespace ck_tile::builder { -/********************************************** - * constexpr helper functions for optional parameters - **********************************************/ - -template -concept ProvidesElementwiseOperation = requires { Sig.elementwiseOperation; }; - -template -concept ProvidesConvolutionDirection = requires { Sig.direction; }; - -template -constexpr auto get_elementwise_operation() -{ - if constexpr(ProvidesElementwiseOperation) - { - return Sig.elementwise_operation; - } - else - { - return ElementwiseOperation::PASS_THROUGH; - } -} - -template -constexpr auto get_conv_direction() -{ - if constexpr(ProvidesConvolutionDirection) - { - return Sig.direction; - } - else - { - return ConvDirection::FORWARD; - } -} -} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index dee918cc1f..0c675ac7f1 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -7,7 +7,6 @@ #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/conv_signature_utils.hpp" #include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" #include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" @@ -25,11 +24,9 @@ template ()); - using Types = internal::ConvTensorTypes; - using Ops = internal::ElementwiseOps()>; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::FwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 383ecbf8c9..98e368ca61 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -8,7 +8,6 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/conv_signature_utils.hpp" #include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" #include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" @@ -27,11 +26,9 @@ template ()); - using Types = internal::ConvTensorTypes; - using Ops = internal::ElementwiseOps()>; + using Layouts = internal::ConvTensorLayouts; + using Types = 
internal::FwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 90d4abe3e7..79955a1f44 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -8,7 +8,6 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/conv_signature_utils.hpp" #include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" #include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" @@ -27,11 +26,9 @@ template ()); - using Types = internal::ConvTensorTypes; - using Ops = internal::ElementwiseOps()>; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::FwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index e35b3f3d46..fcce46aea7 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -8,7 +8,6 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/conv_signature_utils.hpp" #include 
"ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" #include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" @@ -27,11 +26,9 @@ template ()); - using Types = internal::ConvTensorTypes; - using Ops = internal::ElementwiseOps()>; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::FwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index fc5b32f799..df7fb25168 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -8,7 +8,6 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/conv_signature_utils.hpp" #include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" #include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" @@ -27,11 +26,9 @@ template ()); - using Types = internal::ConvTensorTypes; - using Ops = internal::ElementwiseOps()>; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::FwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp index 4a13f4e508..a39cd7410b 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp @@ -6,32 +6,70 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck_tile/builder/builder_utils.hpp" #include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" namespace ck_tile::builder::factory::internal { -template +template +struct ElementwiseOpToCK +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported elementwise operation conversion to CK."); +}; + +template <> +struct ElementwiseOpToCK +{ + using Op = ck::tensor_operation::element_wise::PassThrough; +}; + +template <> +struct ElementwiseOpToCK +{ + using Op = ck::tensor_operation::element_wise::Scale; +}; + +template <> +struct ElementwiseOpToCK +{ + using Op = ck::tensor_operation::element_wise::Clamp; +}; + +template <> +struct ElementwiseOpToCK +{ + using Op = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; +}; + +template <> +struct ElementwiseOpToCK +{ + using Op = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; +}; + +template +consteval auto GetElementwiseOp() +{ + if constexpr(HasTensorOp) + { + constexpr auto op = TensorDesc.operation.elementwise_operation; + return ElementwiseOpToCK{}; + } + else + { + return ElementwiseOpToCK{}; + } +} + +template struct ElementwiseOps { - // This will trigger if a specialization for the given DataType is not found. - // We should always catch this in an earlier validation check. - static_assert(sizeof(UnsupportedEnumValue) == 0, - "Internal error. 
Unsupported elementwise operation for convolution factory."); -}; - -template <> -struct ElementwiseOps -{ - using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough; - using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough; - using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough; -}; - -template <> -struct ElementwiseOps -{ - using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough; - using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough; - using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale; + static constexpr auto input_op = GetElementwiseOp(); + static constexpr auto weight_op = GetElementwiseOp(); + static constexpr auto output_op = GetElementwiseOp(); + using AElementwiseOp = typename decltype(input_op)::Op; + using BElementwiseOp = typename decltype(weight_op)::Op; + using CDEElementwiseOp = typename decltype(output_op)::Op; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp index b3effa782e..a6c0b48c54 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp @@ -6,141 +6,228 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/utility/tuple.hpp" #include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/builder_utils.hpp" namespace ck_tile::builder::factory::internal { -// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types. -template - requires(ConvSpatialDim && ValidConvLayoutForSpatialDim) -struct ConvTensorLayouts +template +struct LayoutToCK { - // This will trigger if a specialization for the given layout is not found. 
- // We should always catch this in an earlier validation check. - using Layout = decltype(LayoutValue); - static_assert(sizeof(Layout) == 0, - "Internal error. Unsupported layout for convolution factory."); + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported layout conversion to CK."); }; -// 1D Forward Convolution Layout Specializations +// Bias layouts template <> -struct ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::NWGC; - using BLayout = ck::tensor_layout::convolution::GKXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NWGK; + using type = ck::tensor_layout::convolution::G_K; }; - template <> -struct ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::NGCW; - using BLayout = ck::tensor_layout::convolution::GKXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NGKW; + using type = ck::tensor_layout::convolution::GC; }; - template <> -struct ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::GNWC; - using BLayout = ck::tensor_layout::convolution::GKXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::GNWK; + using type = ck::tensor_layout::convolution::G_C; }; +// Input 1D template <> -struct ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::NGCW; - using BLayout = ck::tensor_layout::convolution::GKCX; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NGKW; + using type = ck::tensor_layout::convolution::NWGC; }; - template <> -struct ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::NGCHW; - using BLayout = ck::tensor_layout::convolution::GKYXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NGKHW; + using type = ck::tensor_layout::convolution::NGCW; }; - template <> -struct 
ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::NHWGC; - using BLayout = ck::tensor_layout::convolution::GKYXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NHWGK; + using type = ck::tensor_layout::convolution::GNWC; }; +// Input 2D template <> -struct ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::GNHWC; - using BLayout = ck::tensor_layout::convolution::GKYXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::GNHWK; + using type = ck::tensor_layout::convolution::NGCHW; }; - template <> -struct ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::NGCHW; - using BLayout = ck::tensor_layout::convolution::GKCYX; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NGKHW; + using type = ck::tensor_layout::convolution::NHWGC; }; - template <> -struct ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::NGCDHW; - using BLayout = ck::tensor_layout::convolution::GKCZYX; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NGKDHW; + using type = ck::tensor_layout::convolution::GNHWC; }; +// Input 3D template <> -struct ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::NDHWGC; - using BLayout = ck::tensor_layout::convolution::GKZYXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NDHWGK; + using type = ck::tensor_layout::convolution::NGCDHW; }; - template <> -struct ConvTensorLayouts +struct LayoutToCK { - using ALayout = ck::tensor_layout::convolution::GNDHWC; - using BLayout = ck::tensor_layout::convolution::GKZYXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::GNDHWK; + using type = ck::tensor_layout::convolution::NDHWGC; +}; +template <> +struct LayoutToCK +{ + using type = 
ck::tensor_layout::convolution::GNDHWC; }; -template -consteval auto GetTensorLayout() +// Weight 1D +template <> +struct LayoutToCK { + using type = ck::tensor_layout::convolution::GKXC; +}; +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::GKCX; +}; - if constexpr(SPATIAL_DIM == 1) - { - return internal::ConvTensorLayouts{}; - } - else if constexpr(SPATIAL_DIM == 2) - { - return internal::ConvTensorLayouts{}; - } - else if constexpr(SPATIAL_DIM == 3) - { - return internal::ConvTensorLayouts{}; - } - else - { - static_assert(false, "Unsupported spatial dimension for convolution layout."); - } +// Weight 2D +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::GKYXC; +}; +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::GKCYX; +}; + +// Weight 3D +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::GKCZYX; +}; +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::GKZYXC; +}; + +// Output 1D +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::NWGK; +}; +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::NGKW; +}; +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::GNWK; +}; + +// Output 2D +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::NGKHW; +}; +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::NHWGK; +}; +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::GNHWK; +}; + +// Output 3D +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::NGKDHW; +}; +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::NDHWGK; +}; +template <> +struct LayoutToCK +{ + using type = ck::tensor_layout::convolution::GNDHWK; +}; + +template +consteval auto TensorLayoutToCK() +{ + return typename 
LayoutToCK::type{}; } +struct EmptyAuxiliaryTensorLayout +{ + using type = ck::Tuple<>; +}; + +template +consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence) +{ + return ck::Tuple< + decltype(TensorLayoutToCK())...>{}; +} + +template + requires(ConvSpatialDim) +struct AuxiliaryTensorLayouts +{ + static constexpr auto Size = AuxiliaryTensorConfigsValue.size(); + using type = decltype(GetAuxiliaryTensorLayoutTuple( + std::make_index_sequence{})); +}; + +// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias). +template + requires(HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTensorLayouts() +{ + return AuxiliaryTensorLayouts{}; +} + +template + requires(!HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTensorLayouts() +{ + return EmptyAuxiliaryTensorLayout{}; +} + +template + requires(ConvSpatialDim && + ValidConvInputLayoutForSpatialDim && + ValidConvWeightLayoutForSpatialDim && + ValidConvOutputLayoutForSpatialDim) +struct ConvTensorLayouts +{ + static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported."); + using ALayout = decltype(TensorLayoutToCK()); + using BLayout = decltype(TensorLayoutToCK()); + using ELayout = decltype(TensorLayoutToCK()); + using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; +}; + } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp index d8a8eb5da0..81de2140f2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp @@ -6,82 +6,172 @@ #include "ck/utility/data_type.hpp" #include "ck_tile/builder/types.hpp" #include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" namespace
ck_tile::builder::factory::internal { -// Type mappings from builder convolution data type to CK tensor types. -template -struct ConvTensorTypes +template +struct DataTypeToCK { - // This will trigger if a specialization for the given DataType is not found. - // We should always catch this in an earlier validation check. - static_assert(sizeof(UnsupportedEnumValue) == 0, - "Internal error. Unsupported data type for convolution factory."); + // Catch unsupported data types at compile time + static_assert(sizeof(UnsupportedEnumValue
) == 0, "Unsupported data type conversion to CK."); }; template <> -struct ConvTensorTypes +struct DataTypeToCK { - using ADataType = ck::half_t; - using AComputeType = ck::half_t; - using BDataType = ck::half_t; - using BComputeType = ck::half_t; - using CShuffleDataType = ck::half_t; - using DsDataTypes = ck::Tuple<>; - using AccDataType = float; - using EDataType = ck::half_t; + using type = ck::half_t; +}; +template <> +struct DataTypeToCK +{ + using type = ck::bhalf_t; +}; +template <> +struct DataTypeToCK +{ + using type = float; +}; +template <> +struct DataTypeToCK +{ + using type = int32_t; +}; +template <> +struct DataTypeToCK +{ + using type = int8_t; +}; +template <> +struct DataTypeToCK +{ + using type = ck::f8_t; }; -template <> -struct ConvTensorTypes +struct CK_empty_tuple { - using ADataType = ck::bhalf_t; - using AComputeType = ck::bhalf_t; - using BDataType = ck::bhalf_t; - using BComputeType = ck::bhalf_t; - using CShuffleDataType = ck::bhalf_t; - using DsDataTypes = ck::Tuple<>; - using AccDataType = float; - using EDataType = ck::bhalf_t; + using type = ck::Tuple<>; }; -template <> -struct ConvTensorTypes +template +consteval auto ConvertDataTypeToCK() { - using ADataType = float; - using AComputeType = float; - using BDataType = float; - using BComputeType = float; - using CShuffleDataType = float; - using DsDataTypes = ck::Tuple<>; - using AccDataType = float; - using EDataType = float; + return DataTypeToCK
{}; +} + +template +consteval auto GetTensorDataAndComputeTypes() +{ + constexpr auto data_type = Config.data_type; + constexpr auto compute_type = Config.compute_type; + + if constexpr(data_type == DataType::UNDEFINDED && compute_type == DataType::UNDEFINDED) + { + return std::make_pair(ConvertDataTypeToCK(), + ConvertDataTypeToCK()); + } + else if constexpr(data_type == DataType::UNDEFINDED) + { + return std::make_pair(ConvertDataTypeToCK(), + ConvertDataTypeToCK()); + } + else if constexpr(compute_type == DataType::UNDEFINDED) + { + return std::make_pair(ConvertDataTypeToCK(), + ConvertDataTypeToCK()); + } + else + { + return std::make_pair(ConvertDataTypeToCK(), + ConvertDataTypeToCK()); + } +} + +template +consteval auto GetTensorAccumulationType() +{ + constexpr auto data_type = SignatureAccDataType; + if constexpr(data_type == DataType::UNDEFINDED) + { + return ConvertDataTypeToCK(); + } + else + { + return ConvertDataTypeToCK(); + } +} + +template +consteval auto GetAuxiliaryTensorDataTypeValue() +{ + constexpr auto data_type = Config.data_type; + if constexpr(data_type == DataType::UNDEFINDED) + { + return ConvertDataTypeToCK(); + } + else + { + return ConvertDataTypeToCK(); + } +} + +template +consteval auto GetAuxiliaryTensorDataTypeTuple(std::index_sequence) +{ + return ck::Tuple< + typename decltype(GetAuxiliaryTensorDataTypeValue())::type...>{}; +} + +template +struct AuxiliaryTensorDataTypes +{ + static constexpr auto Size = AuxiliaryTensorConfigsValue.size(); + using type = + decltype(GetAuxiliaryTensorDataTypeTuple( + std::make_index_sequence{})); }; -template <> -struct ConvTensorTypes +// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias).
+template + requires(HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTensorDataTypes() { - using ADataType = int8_t; - using AComputeType = int8_t; - using BDataType = int8_t; - using BComputeType = int8_t; - using CShuffleDataType = int8_t; - using DsDataTypes = ck::Tuple<>; - using AccDataType = int32_t; - using EDataType = int8_t; -}; + return AuxiliaryTensorDataTypes{}; +} -template <> -struct ConvTensorTypes +template + requires(!HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTensorDataTypes() { - using ADataType = ck::f8_t; - using AComputeType = ck::f8_t; - using BDataType = ck::f8_t; - using BComputeType = ck::f8_t; - using CShuffleDataType = ck::f8_t; - using DsDataTypes = ck::Tuple<>; - using AccDataType = float; - using EDataType = ck::f8_t; + return CK_empty_tuple{}; +} + +template +struct FwdConvTensorDataTypes +{ + static constexpr auto input_types = + GetTensorDataAndComputeTypes(); + static constexpr auto weight_types = + GetTensorDataAndComputeTypes(); + static constexpr auto output_types = + GetTensorDataAndComputeTypes(); + + using ADataType = typename decltype(input_types.first)::type; + using AComputeType = typename decltype(input_types.second)::type; + using BDataType = typename decltype(weight_types.first)::type; + using BComputeType = typename decltype(weight_types.second)::type; + using AccDataType = + typename decltype(GetTensorAccumulationType())::type; + using EDataType = typename decltype(output_types.first)::type; + + // This is the "compute" type for output. + using CShuffleDataType = typename decltype(output_types.second)::type; + + // Data types for the auxiliary tensors (e.g., bias). 
+ using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes())::type; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index be3c208ba8..261c3f103d 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -41,8 +41,9 @@ struct ConvSignatureInfo { int spatial_dim; builder::ConvDirection direction; - std::variant - layout; + builder::TensorLayout input_layout; + builder::TensorLayout weight_layout; + builder::TensorLayout output_layout; builder::DataType data_type; builder::ElementwiseOperation input_element_op; builder::ElementwiseOperation weight_element_op; @@ -106,7 +107,9 @@ class ConvDescription : public Description f.writeLine(0, signature_.spatial_dim, "D ", signature_.direction, " Convolution Kernel"); f.writeLine(1, "Signature"); f.writeLine(2, "Tensor Type: ", signature_.data_type); - f.writeLine(2, "Memory Layout: ", signature_.layout); + f.writeLine(2, "Input Layout: ", signature_.input_layout); + f.writeLine(2, "Weight Layout: ", signature_.weight_layout); + f.writeLine(2, "Output Layout: ", signature_.output_layout); f.writeLine(2, "Input elementwise operation: ", signature_.input_element_op); f.writeLine(2, "Weights elementwise operation: ", signature_.weight_element_op); f.writeLast(2, "Output elementwise operation: ", signature_.output_element_op); @@ -264,7 +267,9 @@ conv::ConvDescription describe() conv::ConvSignatureInfo{ .spatial_dim = Traits::spatial_dim, .direction = Traits::direction, - .layout = Traits::layout, + .input_layout = Traits::layout[0], + .weight_layout = Traits::layout[1], + .output_layout = Traits::layout[2], .data_type = Traits::data_type, .input_element_op = Traits::input_element_op, .weight_element_op = Traits::weight_element_op, diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 316f570bcd..29ac49e549 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -298,7 +298,10 @@ constexpr auto conv_spec() /// @brief Derives the grouped convolution layout from a device kernel `Instance` type. /// @tparam Instance The device kernel instance type. -/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts. +/// @return An std::array corresponding to the tensor layouts: +/// index 0 -> Input layout +/// index 1 -> Weight layout +/// index 2 -> Output layout template constexpr auto conv_layout() { @@ -314,22 +317,30 @@ constexpr auto conv_layout() if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout1D::GNWC_GKXC_GNWK; + return std::array{builder::TensorLayout::GNWC, + builder::TensorLayout::GKXC, + builder::TensorLayout::GNWK}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout1D::NWGC_GKXC_NWGK; + return std::array{builder::TensorLayout::NWGC, + builder::TensorLayout::GKXC, + builder::TensorLayout::NWGK}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout1D::NGCW_GKXC_NGKW; + return std::array{builder::TensorLayout::NGCW, + builder::TensorLayout::GKXC, + builder::TensorLayout::NGKW}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout1D::NGCW_GKCX_NGKW; + return std::array{builder::TensorLayout::NGCW, + builder::TensorLayout::GKCX, + builder::TensorLayout::NGKW}; } } else if constexpr(InstTraits::kSpatialDim == 2) @@ -337,25 +348,33 @@ constexpr auto conv_layout() if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return 
builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + return std::array{builder::TensorLayout::GNHWC, + builder::TensorLayout::GKYXC, + builder::TensorLayout::GNHWK}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK; + return std::array{builder::TensorLayout::NHWGC, + builder::TensorLayout::GKYXC, + builder::TensorLayout::NHWGK}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW; + return std::array{builder::TensorLayout::NGCHW, + builder::TensorLayout::GKYXC, + builder::TensorLayout::NGKHW}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW; + return std::array{builder::TensorLayout::NGCHW, + builder::TensorLayout::GKCYX, + builder::TensorLayout::NGKHW}; } } else if constexpr(InstTraits::kSpatialDim == 3) @@ -363,25 +382,33 @@ constexpr auto conv_layout() if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK; + return std::array{builder::TensorLayout::GNDHWC, + builder::TensorLayout::GKZYXC, + builder::TensorLayout::GNDHWK}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK; + return std::array{builder::TensorLayout::NDHWGC, + builder::TensorLayout::GKZYXC, + builder::TensorLayout::NDHWGK}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW; + return std::array{builder::TensorLayout::NGCDHW, + builder::TensorLayout::GKZYXC, + builder::TensorLayout::NGKDHW}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW; + return std::array{builder::TensorLayout::NGCDHW, + builder::TensorLayout::GKCZYX, + builder::TensorLayout::NGKDHW}; } 
} } @@ -433,22 +460,10 @@ template constexpr builder::ElementwiseOperation elementwise_op() { constexpr std::string_view name = detail::elementwise_op_name(); - if constexpr(detail::case_insensitive_equal(name, "Bias")) - { - return builder::ElementwiseOperation::BIAS; - } - else if constexpr(detail::case_insensitive_equal(name, "BiasClamp")) - { - return builder::ElementwiseOperation::BIAS_CLAMP; - } - else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) + if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) { return builder::ElementwiseOperation::BIAS_BNORM_CLAMP; } - else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) - { - return builder::ElementwiseOperation::BILINEAR; - } else if constexpr(detail::case_insensitive_equal(name, "Clamp")) { return builder::ElementwiseOperation::CLAMP; @@ -461,6 +476,10 @@ constexpr builder::ElementwiseOperation elementwise_op() { return builder::ElementwiseOperation::PASS_THROUGH; } + else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) + { + return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU; + } } /// @brief Derives a gemm padding from a kernel instance type. diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 1aeb71af10..565bb98528 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -6,64 +6,91 @@ #include #include #include +#include +#include namespace ck_tile::builder { enum class DataType { + UNDEFINDED = 0, FP32, FP16, BF16, FP8, + INT32, I8, U8 }; -// Memory layouts for 1D convolution tensors. -// G: Group, N: Batch, K: Output Channel, C: Input Channel, W: Width -// Enum defines Input, Weight, and Output tensor layouts respectively. 
-enum class GroupConvLayout1D +enum class TensorLayout { - GNWC_GKXC_GNWK, - NWGC_GKXC_NWGK, - NGCW_GKXC_NGKW, - NGCW_GKCX_NGKW -}; + UNDEFINED, -// Memory layouts for 2D convolution tensors. -// G: Group, N: Batch, K: Output Channel, C: Input Channel, Y: Height, X: Width, H: Height -// Enum defines Input, Weight, and Output tensor layouts respectively. -enum class GroupConvLayout2D -{ - GNHWC_GKYXC_GNHWK, - NHWGC_GKYXC_NHWGK, - NGCHW_GKYXC_NGKHW, - NGCHW_GKCYX_NGKHW -}; + // Bias tensors + GC, + G_C_strided, + G_K_strided, -// Memory layouts for 3D convolution tensors. -// G: Group, N: Batch, K: Output Channel, C: Input Channel, Z: Depth, Y: Height, X: Width, D: Depth, -// H: Height Enum defines Input, Weight, and Output tensor layouts respectively. -enum class GroupConvLayout3D -{ - GNDHWC_GKZYXC_GNDHWK, - NDHWGC_GKZYXC_NDHWGK, - NGCDHW_GKZYXC_NGKDHW, - NGCDHW_GKCZYX_NGKDHW, -}; + // 1D conv input tensor + GNCW, + GNWC, + NWGC, + NGCW, + G_NW_C_strided, -struct GroupConvLayout -{ - union - { - GroupConvLayout1D _1d; - GroupConvLayout2D _2d; - GroupConvLayout3D _3d; - }; + // 2D conv input tensor + GNCHW, + GNHWC, + NHWGC, + NGCHW, + G_NHW_C_strided, - constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {} - constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {} - constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {} + // 3D conv input tensor + GNCDHW, + GNDHWC, + NDHWGC, + NGCDHW, + G_NDHW_C_strided, + + // 1D conv weight tensor + GKXC, + GKCX, + KXGC, + G_K_X_C_strided, + + // 2D conv weight tensor + GKYXC, + GKCYX, + KYXGC, + G_K_YX_C_strided, + + // 3D conv weight tensor + GKZYXC, + GKCZYX, + KZYXGC, + G_K_ZYX_C_strided, + + // 1D conv output tensor + GNKW, + GNWK, + NWGK, + NGKW, + G_NW_K_strided, + + // 2D conv output tensor + GNKHW, + GNHWK, + NHWGK, + NGKHW, + G_NHW_K_strided, + + // 3D conv output tensor + GNKDHW, + GNDHWK, + NDHWGK, + NGKDHW, + G_NDHW_K_strided }; // Direction of the convolution operation. 
@@ -77,13 +104,11 @@ enum class ConvDirection // Fused element-wise operations. enum class ElementwiseOperation { - BIAS, - BIAS_CLAMP, BIAS_BNORM_CLAMP, - BILINEAR, - CLAMP, SCALE, - PASS_THROUGH + CLAMP, + PASS_THROUGH, + SCALEADD_SCALEADD_RELU }; // Enums for pipeline versions & schedulers @@ -188,8 +213,10 @@ inline std::ostream& operator<<(std::ostream& os, DataType dt) case FP32: return os << "FP32"; case BF16: return os << "BF16"; case FP8: return os << "FP8"; + case INT32: return os << "INT32"; case I8: return os << "I8"; case U8: return os << "U8"; + case UNDEFINDED: return os << "UNDEFINDED"; default: return os << "Unknown"; } } @@ -206,57 +233,16 @@ inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) } } -inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout) -{ - using enum GroupConvLayout1D; - switch(layout) - { - case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK"; - case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK"; - case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW"; - case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW"; - default: return os << "Unknown"; - } -} - -inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout) -{ - using enum GroupConvLayout2D; - switch(layout) - { - case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK"; - case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK"; - case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW"; - case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW"; - default: return os << "Unknown"; - } -} - -inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout) -{ - using enum GroupConvLayout3D; - switch(layout) - { - case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK"; - case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK"; - case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW"; - case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW"; - default: return os << "Unknown"; - } -} - inline 
std::ostream& operator<<(std::ostream& os, ElementwiseOperation op) { using enum ElementwiseOperation; switch(op) { - case BIAS: return os << "BIAS"; - case BIAS_CLAMP: return os << "BIAS_CLAMP"; - case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP"; - case BILINEAR: return os << "BILINEAR"; case CLAMP: return os << "CLAMP"; case SCALE: return os << "SCALE"; case PASS_THROUGH: return os << "PASS_THROUGH"; + case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP"; + case SCALEADD_SCALEADD_RELU: return os << "SCALEADD_SCALEADD_RELU"; default: return os << "Unknown"; } } @@ -375,13 +361,59 @@ inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched) } } -// ostream operator overload for std::variant of layout types -inline std::ostream& -operator<<(std::ostream& os, - const std::variant& layout) +inline std::ostream& operator<<(std::ostream& os, TensorLayout layout) { - std::visit([&os](const auto& l) { os << l; }, layout); - return os; + using enum TensorLayout; + switch(layout) + { + case GNCW: return os << "GNCW"; + case GNWC: return os << "GNWC"; + case NWGC: return os << "NWGC"; + case NGCW: return os << "NGCW"; + case G_NW_C_strided: return os << "G_NW_C_strided"; + case GNCHW: return os << "GNCHW"; + case GNHWC: return os << "GNHWC"; + case NHWGC: return os << "NHWGC"; + case NGCHW: return os << "NGCHW"; + case G_NHW_C_strided: return os << "G_NHW_C_strided"; + case GNCDHW: return os << "GNCDHW"; + case GNDHWC: return os << "GNDHWC"; + case NDHWGC: return os << "NDHWGC"; + case NGCDHW: return os << "NGCDHW"; + case G_NDHW_C_strided: return os << "G_NDHW_C_strided"; + case GKXC: return os << "GKXC"; + case GKCX: return os << "GKCX"; + case KXGC: return os << "KXGC"; + case G_K_X_C_strided: return os << "G_K_X_C_strided"; + case GKYXC: return os << "GKYXC"; + case GKCYX: return os << "GKCYX"; + case KYXGC: return os << "KYXGC"; + case G_K_YX_C_strided: return os << "G_K_YX_C_strided"; + case GKZYXC: return os << "GKZYXC"; + case GKCZYX: return 
os << "GKCZYX"; + case KZYXGC: return os << "KZYXGC"; + case G_K_ZYX_C_strided: return os << "G_K_ZYX_C_strided"; + case GNKW: return os << "GNKW"; + case GNWK: return os << "GNWK"; + case NWGK: return os << "NWGK"; + case NGKW: return os << "NGKW"; + case G_NW_K_strided: return os << "G_NW_K_strided"; + case GNKHW: return os << "GNKHW"; + case GNHWK: return os << "GNHWK"; + case NHWGK: return os << "NHWGK"; + case NGKHW: return os << "NGKHW"; + case G_NHW_K_strided: return os << "G_NHW_K_strided"; + case GNKDHW: return os << "GNKDHW"; + case GNDHWK: return os << "GNDHWK"; + case NDHWGK: return os << "NDHWGK"; + case NGKDHW: return os << "NGKDHW"; + case G_NDHW_K_strided: return os << "G_NDHW_K_strided"; + case GC: return os << "GC"; + case G_C_strided: return os << "G_C_strided"; + case G_K_strided: return os << "G_K_strided"; + case UNDEFINED: return os << "UNDEFINED"; + default: return os << "Unknown"; + } } // ostream operator overload for std::variant of convolution specializations diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index e43c88c7a7..a340a789de 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -119,6 +119,7 @@ add_ck_builder_test(test_ckb_instance_string # Tests the forward convolution builder across multiple data types and dimensions. # Individual tests are split into separate files to enable parallel compilation. 
add_ck_builder_test(test_ckb_build_fwd_instances + conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp conv/test_ckb_conv_fwd_1d_fp16.cpp conv/test_ckb_conv_fwd_1d_bf16.cpp conv/test_ckb_conv_fwd_1d_i8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index 1cace0cf9a..937d17a1ff 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -13,11 +13,15 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, - .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::SCALE}; + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .data_type = DataType::BF16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NGCW}}, + .weight = {.config = {.layout = TensorLayout::GKXC}}, + .output = {.config = {.layout = TensorLayout::NGKW}, + .operation = {.elementwise_operation = ElementwiseOperation::SCALE}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -30,10 +34,13 @@ TEST(FwdConvInstances, using Builder = ConvBuilder; run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256, 256, 256, 32", + "256,256,256,32", + "NGCW,GKXC,EmptyTuple,NGKW", + "PassThrough,PassThrough,Scale", "Filter1x1Stride1Pad0", - "BlkGemmPipelineScheduler: Intrawave", - "BlkGemmPipelineVersion: v2"}); + "MNKPadding", + "Intrawave", + "v2"}); } } // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp index 
3315eb6f64..e8cd8fb136 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp @@ -10,14 +10,15 @@ using namespace ck_tile::builder::test_utils; // 1D FP16 (channels-last) with DEFAULT specialization TEST(FwdConvInstances, - Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst_scale) + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout1D::NWGC_GKXC_NWGK, - .data_type = DataType::FP16, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NWGC}}, + .weight = {.config = {.layout = TensorLayout::GKXC}}, + .output = {.config = {.layout = TensorLayout::NWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} @@ -28,8 +29,12 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; - run_test( - {"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "64, 64, 32, 32", "Default"}); + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", + "NWGC,GKXC,EmptyTuple,NWGK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding", + "64,64,32,32", + "Default"}); } } // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp index f6b18747b7..014e221101 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp @@ -14,12 +14,13 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, 
Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout1D::GNWC_GKXC_GNWK, - .data_type = DataType::I8, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .data_type = DataType::I8, + .accumulation_data_type = DataType::INT32, + .input = {.config = {.layout = TensorLayout::GNWC}}, + .weight = {.config = {.layout = TensorLayout::GKXC}}, + .output = {.config = {.layout = TensorLayout::GNWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} @@ -30,8 +31,11 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; - run_test( - {"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", "128, 64, 64, 64", "Default"}); + run_test({"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", + "128,64,64,64", + "GNWC,GKXC,EmptyTuple,GNWK", + "PassThrough,PassThrough,PassThrough", + "Default"}); } #endif diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index e0dc3225fa..b98e28c45a 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -12,12 +12,13 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, - .data_type = DataType::BF16, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + 
.data_type = DataType::BF16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -29,22 +30,26 @@ TEST(FwdConvInstances, using Builder = ConvBuilder; run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256, 256, 256, 32", + "256,256,256,32", "Default", - "BlkGemmPipelineScheduler: Intrawave", - "BlkGemmPipelineVersion: v1"}); + "NHWGC,GKYXC,EmptyTuple,NHWGK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding", + "Intrawave", + "v1"}); } // 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3 TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, - .data_type = DataType::BF16, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::BF16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -57,7 +62,10 @@ TEST(FwdConvInstances, using Builder = ConvBuilder; run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", "Filter3x3", - "BlkGemmPipelineVersion: v5"}); + "NHWGC,GKYXC,EmptyTuple,NHWGK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding", + "v5"}); } } // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp 
b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp new file mode 100644 index 0000000000..bc4a5e1047 --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder; +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_BF16_scale_add_relu) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::BF16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC, .data_type = DataType::BF16}}, + .output = ConvolutionTensor{ + .config = {.layout = TensorLayout::NHWGK}, + .operation = TensorOperation<>{.elementwise_operation = + ElementwiseOperation::SCALEADD_SCALEADD_RELU} + .with_auxiliary_operand_configs()}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + .with_thread_block(FwdThreadBlock_64_64x32x32) + .with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave) + .with_transfer(FwdTransfer_4x16x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", + "NHWGC,GKYXC,Tuple(NHWGK,G_K),NHWGK", + "PassThrough,PassThrough,ScaleAddScaleAddRelu", + "64,64,32,32", + "MNKPadding", + "Default"}); +} + +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp index 4c4d128717..7af1448403 100644 
--- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -10,12 +10,13 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::GNHWC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} @@ -26,19 +27,24 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins .with_dl_transfer(DlFwdTransfer); using Builder = ConvBuilder; - run_test( - {"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Default"}); + run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", + "256,128,128,16", + "Default", + "MNKPadding", + "GNHWC,GKYXC,EmptyTuple,GNHWK", + "PassThrough,PassThrough,PassThrough"}); } TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = 
{.config = {.layout = TensorLayout::GNHWC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} @@ -50,8 +56,12 @@ TEST(FwdConvInstances, .with_dl_transfer(DlFwdTransfer); using Builder = ConvBuilder; - run_test( - {"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Filter1x1Pad0"}); + run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", + "256,128,128,16", + "Filter1x1Pad0", + "MNKPadding", + "GNHWC,GKYXC,EmptyTuple,GNHWK", + "PassThrough,PassThrough,PassThrough"}); } } // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index 36b44ffb41..7b522403d3 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -11,12 +11,13 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::GNHWC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -29,10 +30,13 @@ TEST(FwdConvInstances, using Builder = ConvBuilder; run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256, 256, 256, 
32", + "256,256,256,32", "Filter1x1Pad0", - "BlkGemmPipelineScheduler: Intrawave", - "BlkGemmPipelineVersion: v3"}); + "Intrawave", + "v3", + "GNHWC,GKYXC,EmptyTuple,GNHWK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); } } // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index b2943d91b9..615d098c7c 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -11,12 +11,13 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, - .data_type = DataType::FP32, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP32, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NGCHW}}, + .weight = {.config = {.layout = TensorLayout::GKCYX}}, + .output = {.config = {.layout = TensorLayout::NGKHW}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -29,10 +30,13 @@ TEST(FwdConvInstances, using Builder = ConvBuilder; run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256, 128, 128, 32", + "256,128,128,32", "Filter1x1Stride1Pad0", - "BlkGemmPipelineScheduler: Intrawave", - "BlkGemmPipelineVersion: v4"}); + "Intrawave", + "v4", + "NGCHW,GKCYX,EmptyTuple,NGKHW", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); } } // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp index d24df998fd..4dd9e2beef 
100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp @@ -12,12 +12,13 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, - .data_type = DataType::FP8, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP8, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} @@ -28,8 +29,12 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; - run_test( - {"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "256, 256, 128, 32", "Default"}); + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", + "256,256,128,32", + "Default", + "NHWGC,GKYXC,EmptyTuple,NHWGK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); } } // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index be0ea3d0a5..8fe58dbe82 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -11,12 +11,13 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature 
FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::GNHWC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ @@ -30,20 +31,24 @@ TEST(FwdConvInstances, using Builder = ConvBuilder; run_test({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor", - "256, 256, 128, 32", - "Default"}); + "256,256,128,32", + "Default", + "GNHWC,GKYXC,EmptyTuple,GNHWK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); } TEST( FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::GNHWC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ @@ -57,8 +62,11 @@ TEST( using Builder = ConvBuilder; run_test({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor", - "128, 128, 128, 32", - "Filter1x1Pad0"}); + "128,128,128,32", + 
"Filter1x1Pad0", + "GNHWC,GKYXC,EmptyTuple,GNHWK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); } } // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index 0db89669f7..2df76ab3e0 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -12,12 +12,14 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, - .data_type = DataType::BF16, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 3, + .direction = ConvDirection::FORWARD, + .data_type = DataType::BF16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::GNDHWC}}, + .weight = {.config = {.layout = TensorLayout::GKZYXC}}, + .output = {.config = {.layout = TensorLayout::GNDHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -29,10 +31,13 @@ TEST(FwdConvInstances, using Builder = ConvBuilder; run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256, 256, 256, 32", + "256,256,256,32", "Default", - "BlkGemmPipelineScheduler: Intrawave", - "BlkGemmPipelineVersion: v3"}); + "Intrawave", + "v3", + "GNDHWC,GKZYXC,EmptyTuple,GNDHWK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); } } // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 80e12f9572..ad626d9a15 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ 
b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -12,12 +12,14 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, - .data_type = DataType::FP16, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 3, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NDHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKZYXC}}, + .output = {.config = {.layout = TensorLayout::NDHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -30,10 +32,13 @@ TEST(FwdConvInstances, using Builder = ConvBuilder; run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256, 128, 128, 32", + "256,128,128,32", "Filter1x1Pad0", - "BlkGemmPipelineScheduler: Intrawave", - "BlkGemmPipelineVersion: v4"}); + "Intrawave", + "v4", + "NDHWGC,GKZYXC,EmptyTuple,NDHWGK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); } } // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index bfddd6efcb..85974ace5d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -12,12 +12,14 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, - 
.data_type = DataType::FP32, - .elementwise_operation = - ElementwiseOperation::PASS_THROUGH}; + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 3, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP32, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NGCDHW}}, + .weight = {.config = {.layout = TensorLayout::GKCZYX}}, + .output = {.config = {.layout = TensorLayout::NGKDHW}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -30,10 +32,13 @@ TEST(FwdConvInstances, using Builder = ConvBuilder; run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256, 256, 256, 32", + "256,256,256,32", "Filter1x1Pad0", - "BlkGemmPipelineScheduler: Intrawave", - "BlkGemmPipelineVersion: v1"}); + "Intrawave", + "v1", + "NGCDHW,GKCZYX,EmptyTuple,NGKDHW", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); } } // namespace diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/test_conv_traits.cpp index 3684ae1c86..a6a7694703 100644 --- a/experimental/builder/test/conv/test_conv_traits.cpp +++ b/experimental/builder/test/conv/test_conv_traits.cpp @@ -85,7 +85,10 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) // Verify signature information EXPECT_EQ(Traits::spatial_dim, 2); EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); - EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_THAT(Traits::layout, + ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC, + ck_tile::builder::TensorLayout::GKYXC, + ck_tile::builder::TensorLayout::GNHWK)); EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); @@ -212,7 +215,10 @@ TEST_F(ConvTraitsTest, 
ConvFwdBaseTraitsExtraction) // Verify signature information EXPECT_EQ(Traits::spatial_dim, 2); EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); - EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_THAT(Traits::layout, + ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC, + ck_tile::builder::TensorLayout::GKYXC, + ck_tile::builder::TensorLayout::GNHWK)); EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); @@ -295,7 +301,10 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) // Verify signature information EXPECT_EQ(Traits::spatial_dim, 2); EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); - EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_THAT(Traits::layout, + ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC, + ck_tile::builder::TensorLayout::GKYXC, + ck_tile::builder::TensorLayout::GNHWK)); EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); diff --git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index f18abb1c8d..ef87981c3d 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -10,14 +10,48 @@ namespace ck_tile::builder::test { using namespace ck_tile::builder; +struct TensorConfig +{ + TensorLayout layout; + // Optional data types, override the type defined in the signature if provided. 
+ DataType data_type{DataType::UNDEFINDED}; + DataType compute_type{DataType::UNDEFINDED}; +}; + +template +struct TensorOperation +{ + ElementwiseOperation elementwise_operation{ElementwiseOperation::PASS_THROUGH}; + std::array auxiliary_operand_configs{Configs...}; + + // Add builder to add auxiliary tensor configs + template + constexpr auto with_auxiliary_operand_configs() const + { + return TensorOperation{ + .elementwise_operation = this->elementwise_operation}; + } +}; + +template > +struct ConvolutionTensor +{ + TensorConfig config; + Op operation{}; +}; + +template , + typename WeightTensor = ConvolutionTensor<>, + typename OutputTensor = ConvolutionTensor<>> struct ConvSignature { int spatial_dim; ConvDirection direction; - GroupConvLayout layout; DataType data_type; - ElementwiseOperation elementwise_operation; + DataType accumulation_data_type; + InputTensor input; + WeightTensor weight; + OutputTensor output; }; -static_assert(ConvSignatureDescriptor); } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 5480c2740a..689577fb3b 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -16,40 +16,79 @@ namespace ckb = ck_tile::builder; namespace ckr = ck_tile::reflect; namespace ckt = ck_tile::test; +struct TensorOp +{ + ckb::ElementwiseOperation elementwise_operation{ckb::ElementwiseOperation::PASS_THROUGH}; +}; + +struct InvalidTensorOp +{ + int elementwise_operation = 7; // invalid value +}; +static_assert(!ckb::TensorOperatorDescriptor); + +struct TensorConfig +{ + ckb::TensorLayout layout; + ckb::DataType data_type{ckb::DataType::UNDEFINDED}; + ckb::DataType compute_type{ckb::DataType::UNDEFINDED}; +}; + +struct ConvTensorSimple +{ + TensorConfig config; +}; + +struct ConvTensorWithOp +{ + TensorConfig config; + TensorOp operation{}; +}; + +struct ConvTensorWithInvalidOp +{ 
+ TensorConfig config; + InvalidTensorOp operation{}; +}; + // Defines the signature of the convolution operation to be tested. // This includes dimensionality, direction, data layout, and data type. struct ConvSignature { - int spatial_dim = 2; - ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; - ckb::DataType data_type = ckb::DataType::FP16; - // ckb::GroupConvDeviceOp device_operation = - // ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; + int spatial_dim = 2; + ckb::DataType data_type = ckb::DataType::FP16; + ckb::DataType accumulation_data_type = ckb::DataType::FP32; + ConvTensorSimple input = {.config = {ckb::TensorLayout::GNHWC}}; + ConvTensorSimple weight = {.config = {ckb::TensorLayout::GKYXC}}; + ConvTensorSimple output = {.config = {ckb::TensorLayout::GNHWK}}; }; static_assert(ckb::ConvSignatureDescriptor); // Compile time tests for concepts struct ConvSignatureWithOptionalParams { - int spatial_dim = 2; - ckb::ConvDirection direction = ckb::ConvDirection::FORWARD; - ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; - ckb::DataType data_type = ckb::DataType::FP16; - ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH; + int spatial_dim = 2; + ckb::DataType data_type = ckb::DataType::FP16; + ckb::DataType accumulation_data_type = ckb::DataType::FP32; + ckb::ConvDirection direction = ckb::ConvDirection::FORWARD; + ConvTensorWithOp input = { + .config = {ckb::TensorLayout::GNHWC, ckb::DataType::FP16}, + }; + ConvTensorWithOp weight = {.config = {ckb::TensorLayout::GKYXC, ckb::DataType::FP16}}; + ConvTensorWithOp output = {.config = {ckb::TensorLayout::GNHWK, ckb::DataType::FP16}, + .operation = {ckb::ElementwiseOperation::SCALE}}; }; static_assert(ckb::ConvSignatureDescriptor); struct ConvSignatureWithInvalidOptionalParams { - int spatial_dim = 2; - ckb::ConvDirection direction = ckb::ConvDirection::FORWARD; - ckb::GroupConvLayout layout 
= ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; - ckb::DataType data_type = ckb::DataType::FP16; - int elementwise_operation = 7; // this should fail - // ckb::GroupConvDeviceOp device_operation = - // ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; + int spatial_dim = 2; + ckb::DataType data_type = ckb::DataType::FP16; + ckb::DataType accumulation_data_type = ckb::DataType::FP32; + ConvTensorWithInvalidOp input = {.config = {ckb::TensorLayout::GNHWC}}; + ConvTensorWithInvalidOp weight = {.config = {ckb::TensorLayout::GKYXC}}; + ConvTensorWithInvalidOp output = {.config = {ckb::TensorLayout::GNHWK}}; }; - static_assert(!ckb::ConvSignatureDescriptor); struct DefaultAlgorithm @@ -123,7 +162,9 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) "2D Forward Convolution Kernel\n" "├─ Signature\n" "│ ├─ Tensor Type: FP16\n" - "│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n" + "│ ├─ Input Layout: GNHWC\n" + "│ ├─ Weight Layout: GKYXC\n" + "│ ├─ Output Layout: GNHWK\n" "│ ├─ Input elementwise operation: PASS_THROUGH\n" "│ ├─ Weights elementwise operation: PASS_THROUGH\n" "│ └─ Output elementwise operation: PASS_THROUGH\n" diff --git a/experimental/builder/test/unit_conv_elementwise_op.cpp b/experimental/builder/test/unit_conv_elementwise_op.cpp index 66593bf802..84a9c533f6 100644 --- a/experimental/builder/test/unit_conv_elementwise_op.cpp +++ b/experimental/builder/test/unit_conv_elementwise_op.cpp @@ -8,30 +8,38 @@ namespace { -using ::ck_tile::builder::factory::internal::ElementwiseOps; -using enum ::ck_tile::builder::ElementwiseOperation; +using ::ck_tile::builder::ElementwiseOperation; +using ::ck_tile::builder::factory::internal::ElementwiseOpToCK; TEST(ConvElementwiseOp, AssignsOpsForPassThrough) { - using Ops = ElementwiseOps; - - EXPECT_TRUE( - (std::is_same_v)); - EXPECT_TRUE( - (std::is_same_v)); - EXPECT_TRUE( - (std::is_same_v)); + using Op = ElementwiseOpToCK::Op; + EXPECT_TRUE((std::is_same_v)); } 
TEST(ConvElementwiseOp, AssignsOpsForScale) { - using Ops = ElementwiseOps; + using Op = ElementwiseOpToCK::Op; + EXPECT_TRUE((std::is_same_v)); +} +TEST(ConvElementwiseOp, AssignsOpsForClamp) +{ + using Op = ElementwiseOpToCK::Op; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvElementwiseOp, AssignsOpsForScaleAddScaleAddRelu) +{ + using Op = ElementwiseOpToCK::Op; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvElementwiseOp, AssignsOpsForBiasNormClamp) +{ + using Op = ElementwiseOpToCK::Op; EXPECT_TRUE( - (std::is_same_v)); - EXPECT_TRUE( - (std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + (std::is_same_v)); } } // namespace diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index 6cdcc429dd..7764e94dc6 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -4,116 +4,481 @@ #include #include -// Include the helper file we're testing #include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" +#include "impl/conv_signature_types.hpp" namespace { namespace ckb = ::ck_tile::builder; +using ::ck_tile::builder::DataType; +using ::ck_tile::builder::ElementwiseOperation; +using ::ck_tile::builder::TensorLayout; +using ::ck_tile::builder::factory::internal::AuxiliaryTensorLayouts; using ::ck_tile::builder::factory::internal::ConvTensorLayouts; -using ::ck_tile::builder::factory::internal::GetTensorLayout; +using ::ck_tile::builder::factory::internal::LayoutToCK; + +using namespace ::ck_tile::builder::test; using enum ::ck_tile::builder::ConvDirection; TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) { - using TensorLayouts = ConvTensorLayouts; + static constexpr auto sig = + ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NWGC}}, + .weight = {.config = {.layout = 
TensorLayout::GKXC}}, + .output = {.config = {.layout = TensorLayout::NWGK}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); } TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) { - using TensorLayouts = ConvTensorLayouts; + static constexpr auto sig = + ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NGCW}}, + .weight = {.config = {.layout = TensorLayout::GKXC}}, + .output = {.config = {.layout = TensorLayout::NGKW}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); } TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) { - using TensorLayouts = ConvTensorLayouts; + static constexpr auto sig = + ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::GNWC}}, + .weight = {.config = {.layout = TensorLayout::GKXC}}, + .output = {.config = {.layout = TensorLayout::GNWK}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); } TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) { - using TensorLayouts = ConvTensorLayouts; + static constexpr auto sig = + ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NGCW}}, + .weight = {.config = {.layout = TensorLayout::GKCX}}, + .output = {.config = {.layout = TensorLayout::NGKW}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); 
EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); } TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) { - using TensorLayouts = ConvTensorLayouts; + static constexpr auto sig = + ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NGCHW}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NGKHW}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); } TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) { - using TensorLayouts = ConvTensorLayouts; + static constexpr auto sig = + ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); } TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) { - using TensorLayouts = ConvTensorLayouts; + static constexpr auto sig = + ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::GNHWC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::GNHWK}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); } TEST(ConvTensorLayout, 
AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) { - using TensorLayouts = ConvTensorLayouts; + static constexpr auto sig = + ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NGCHW}}, + .weight = {.config = {.layout = TensorLayout::GKCYX}}, + .output = {.config = {.layout = TensorLayout::NGKHW}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); } TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) { - using TensorLayouts = - ConvTensorLayouts; + static constexpr auto sig = + ConvSignature<>{.spatial_dim = 3, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NGCDHW}}, + .weight = {.config = {.layout = TensorLayout::GKCZYX}}, + .output = {.config = {.layout = TensorLayout::NGKDHW}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); } TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) { - using TensorLayouts = - ConvTensorLayouts; + static constexpr auto sig = + ConvSignature<>{.spatial_dim = 3, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NDHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKZYXC}}, + .output = {.config = {.layout = TensorLayout::NDHWGK}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); } TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK) { - using TensorLayouts = - ConvTensorLayouts; + static constexpr auto sig = + 
ConvSignature<>{.spatial_dim = 3, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::GNDHWC}}, + .weight = {.config = {.layout = TensorLayout::GKZYXC}}, + .output = {.config = {.layout = TensorLayout::GNDHWK}}}; + + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v>)); +} + +TEST(AuxiliaryTensorLayout, AssignsLayoutForG_K_strided) +{ + using CKLayout = LayoutToCK::type; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(AuxiliaryTensorLayout, AssignsLayoutForGC) +{ + using CKLayout = LayoutToCK::type; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(AuxiliaryTensorLayout, AssignsLayoutForG_C_strided) +{ + using CKLayout = LayoutToCK::type; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(AuxiliaryTensorLayout, EmptyAuxiliaryTensorLayoutIsEmptyTuple) +{ + using ::ck_tile::builder::factory::internal::EmptyAuxiliaryTensorLayout; + using EmptyLayout = EmptyAuxiliaryTensorLayout::type; + EXPECT_TRUE((std::is_same_v>)); +} + +struct MockAuxiliaryTensorConfig +{ + TensorLayout layout; +}; + +TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout) +{ + static constexpr std::array aux_configs = { + MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}}; + + using AuxLayouts = AuxiliaryTensorLayouts; + + EXPECT_EQ(AuxLayouts::Size, 1); + using ExpectedType = ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout) +{ + static constexpr std::array aux_configs = { + MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; + + using AuxLayouts = AuxiliaryTensorLayouts; + + EXPECT_EQ(AuxLayouts::Size, 1); + using ExpectedType = ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout) +{ + static constexpr std::array aux_configs = { + 
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}}; + + using AuxLayouts = AuxiliaryTensorLayouts; + + EXPECT_EQ(AuxLayouts::Size, 1); + using ExpectedType = ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors) +{ + static constexpr std::array aux_configs = { + MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}, + MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; + + using AuxLayouts = AuxiliaryTensorLayouts; + + EXPECT_EQ(AuxLayouts::Size, 2); + using ExpectedType = + ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors) +{ + static constexpr std::array aux_configs = { + MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}, + MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}, + MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}}; + + using AuxLayouts = AuxiliaryTensorLayouts; + + EXPECT_EQ(AuxLayouts::Size, 3); + using ExpectedType = ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution) +{ + static constexpr std::array aux_configs = { + MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}}; + + using AuxLayouts = AuxiliaryTensorLayouts; + + EXPECT_EQ(AuxLayouts::Size, 1); + using ExpectedType = ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution) +{ + static constexpr std::array aux_configs = { + MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; + + using AuxLayouts = AuxiliaryTensorLayouts; + + EXPECT_EQ(AuxLayouts::Size, 1); + using ExpectedType = ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) +{ + using OutputOp = TensorOperation; + + static constexpr auto sig = + ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ + .spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::FP16, + 
.accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NGCHW}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NGKHW}, + .operation = + OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; + + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + + using ExpectedDsLayout = ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) +{ + using OutputOp = TensorOperation; + + static constexpr auto sig = + ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ + .spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::BF16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}, + .operation = + OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; + + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + + using ExpectedDsLayout = ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) +{ + using OutputOp = TensorOperation; + + static constexpr auto sig = + ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ + .spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::GNHWC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::GNHWK}, + .operation = OutputOp{.elementwise_operation = + ElementwiseOperation::SCALEADD_SCALEADD_RELU}}}; + + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + 
EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + + using ExpectedDsLayout = + ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) +{ + using OutputOp = TensorOperation; + + static constexpr auto sig = + ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ + .spatial_dim = 1, + .direction = FORWARD, + .data_type = DataType::FP32, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NWGC}}, + .weight = {.config = {.layout = TensorLayout::GKXC}}, + .output = {.config = {.layout = TensorLayout::NWGK}, + .operation = + OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; + + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + + using ExpectedDsLayout = ck::Tuple; + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias) +{ + using OutputOp = TensorOperation; + + static constexpr auto sig = + ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ + .spatial_dim = 3, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NDHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKZYXC}}, + .output = {.config = {.layout = TensorLayout::NDHWGK}, + .operation = OutputOp{.elementwise_operation = + ElementwiseOperation::BIAS_BNORM_CLAMP}}}; + + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + + using ExpectedDsLayout = ck::Tuple; + EXPECT_TRUE((std::is_same_v)); } } // namespace diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp index 5aa82774da..c92b24626e 100644 --- a/experimental/builder/test/unit_conv_tensor_type.cpp +++ b/experimental/builder/test/unit_conv_tensor_type.cpp @@ 
-9,71 +9,42 @@ namespace { namespace ckb = ck_tile::builder; -using ck_tile::builder::factory::internal::ConvTensorTypes; +using ck_tile::builder::factory::internal::DataTypeToCK; TEST(ConvTensorType, AssignsTypesForFP16) { - using Types = ConvTensorTypes; - - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + using CKType = DataTypeToCK::type; + EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorType, AssignsTypesForBF16) { - using Types = ConvTensorTypes; - - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + using CKType = DataTypeToCK::type; + EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorType, AssignsTypesForFP32) { - using Types = ConvTensorTypes; + using CKType = DataTypeToCK::type; + EXPECT_TRUE((std::is_same_v)); +} - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); +TEST(ConvTensorType, AssignsTypesForINT32) +{ + using CKType = DataTypeToCK::type; + EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorType, AssignsTypesForI8) { - using Types = ConvTensorTypes; - - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + using CKType = DataTypeToCK::type; + EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorType, AssignsTypesForFP8) { - using Types = ConvTensorTypes; - - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - 
EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + using CKType = DataTypeToCK::type; + EXPECT_TRUE((std::is_same_v)); } } // namespace diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 5436755608..403c2ffd79 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -178,6 +178,9 @@ constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}; +constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ + .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}; + constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}; diff --git a/experimental/builder/test/utils/ckb_conv_test_utils.hpp b/experimental/builder/test/utils/ckb_conv_test_utils.hpp index f3db734da8..508c621c2e 100644 --- a/experimental/builder/test/utils/ckb_conv_test_utils.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_utils.hpp @@ -15,7 +15,7 @@ constexpr void run_test(const std::vector& kernel_instance_componen { auto instance = typename Builder::Instance{}; - const auto kernel_string = instance.GetTypeString(); + const auto kernel_string = instance.GetInstanceString(); std::cout << "Generated kernel: " << kernel_string << std::endl; EXPECT_GT(kernel_string.size(), 0); From cd21e20ae7d4d3a6309ce238bb94814e145585d6 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 4 Dec 2025 06:58:42 -0800 Subject: [PATCH 05/65] build latest hipblaslt in ck_pytorch docker (#3347) --- Dockerfile.pytorch | 
11 ++++++++++- Jenkinsfile | 6 +++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/Dockerfile.pytorch b/Dockerfile.pytorch index 1b71b00fbb..4533166c06 100644 --- a/Dockerfile.pytorch +++ b/Dockerfile.pytorch @@ -20,4 +20,13 @@ RUN groupadd -g 109 render && \ git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \ chown -R jenkins:jenkins /tmp/pytorch && \ chmod -R a+rwx /tmp/pytorch && \ - sudo usermod -aG irc jenkins + sudo usermod -aG irc jenkins && \ + #install hipblaslt + git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \ + cd rocm-libraries && \ + git checkout develop && \ + git sparse-checkout init --cone && \ + git sparse-checkout set projects/hipblaslt shared/origami && \ + cd projects/hipblaslt && \ + git show --oneline -s && \ + CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --logic-yaml-filter gfx950/*/* --architecture="gfx942;gfx950" -j 128 --skip_rocroller diff --git a/Jenkinsfile b/Jenkinsfile index a2e5b3d20b..45fd576ab6 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -447,7 +447,7 @@ def get_docker_options(){ dockerOpts = "--network=host --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" } else{ //only add kfd and dri paths if you actually going to run somthing on GPUs - dockerOpts = "--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + dockerOpts = "--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" } if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ // the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with @@ -1003,7 +1003,7 @@ def run_aiter_tests(Map conf=[:]){ checkout 
scm //use the latest pytorch image def image = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter" - def dockerOpts=get_docker_options() + def dockerOpts=get_docker_options() + ' --group-add irc ' gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'composable_kernel') { try @@ -1055,7 +1055,7 @@ def run_pytorch_tests(Map conf=[:]){ checkout scm //use the latest pytorch-nightly image def image = "${env.CK_DOCKERHUB}:ck_pytorch" - def dockerOpts=get_docker_options() + def dockerOpts=get_docker_options() + ' --group-add irc ' gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'composable_kernel') { try From d9d4c9c3dfe38fe54bae5b3b1b9b523b011992dd Mon Sep 17 00:00:00 2001 From: spolifroni-amd Date: Thu, 4 Dec 2025 14:09:21 -0500 Subject: [PATCH 06/65] [composable_kernel] initial draft of the ck tile conceptual doc (#3242) * Adding CK Tile documentation * Updates based on feedback * Fix tile window API description * Fix remaining images * add documentation about flush_cache and rotating_buffer functionality in ck_tile * Supplement the documentation * light edit of the ck tile conceptual doc * Fixes for ruff check. * Fixes for ruff check 2. * Fixes for ruff check 3. 
--------- Co-authored-by: Vidyasagar Co-authored-by: AviralGoelAMD Co-authored-by: ThomasNing Co-authored-by: Vidyasagar Ananthan --- docs/conceptual/ck_tile/CK-tile-index.rst | 33 + docs/conceptual/ck_tile/MERMAID_DIAGRAMS.md | 156 ++++ docs/conceptual/ck_tile/adaptors.rst | 391 +++++++++ docs/conceptual/ck_tile/buffer_views.rst | 443 ++++++++++ .../ck_tile/cache_flushing_benchmarking.rst | 390 +++++++++ .../ck_tile/convert_mermaid_to_svg.py | 224 +++++ .../ck_tile/convert_raw_html_to_commented.py | 84 ++ .../ck_tile/convolution_example.rst | 567 +++++++++++++ .../ck_tile/coordinate_movement.rst | 532 ++++++++++++ .../conceptual/ck_tile/coordinate_systems.rst | 612 ++++++++++++++ docs/conceptual/ck_tile/descriptors.rst | 383 +++++++++ .../ck_tile/diagrams/adaptors_1.svg | 1 + .../ck_tile/diagrams/adaptors_2.svg | 1 + .../ck_tile/diagrams/buffer_views_1.svg | 1 + .../ck_tile/diagrams/buffer_views_2.svg | 1 + .../ck_tile/diagrams/buffer_views_3.svg | 1 + .../ck_tile/diagrams/buffer_views_4.svg | 1 + .../ck_tile/diagrams/convolution_example.svg | 1 + .../ck_tile/diagrams/coordinate_movement.svg | 1 + .../ck_tile/diagrams/coordinate_systems_1.svg | 1 + .../ck_tile/diagrams/coordinate_systems_2.svg | 1 + .../ck_tile/diagrams/coordinate_systems_3.svg | 1 + .../ck_tile/diagrams/coordinate_systems_4.svg | 1 + .../ck_tile/diagrams/coordinate_systems_5.svg | 1 + .../ck_tile/diagrams/coordinate_systems_6.svg | 1 + .../ck_tile/diagrams/descriptors_1.svg | 1 + .../ck_tile/diagrams/descriptors_2.svg | 1 + .../ck_tile/diagrams/encoding_internals_1.svg | 1 + .../ck_tile/diagrams/encoding_internals_2.svg | 1 + .../diagrams/introduction_motivation_1.svg | 1 + .../diagrams/introduction_motivation_2.svg | 1 + .../ck_tile/diagrams/lds_index_swapping_1.svg | 1 + .../ck_tile/diagrams/lds_index_swapping_2.svg | 1 + .../ck_tile/diagrams/lds_index_swapping_3.svg | 1 + .../ck_tile/diagrams/load_store_traits_1.svg | 1 + .../ck_tile/diagrams/load_store_traits_2.svg | 1 + 
.../ck_tile/diagrams/space_filling_curve.svg | 1 + .../diagrams/static_distributed_tensor.svg | 1 + .../ck_tile/diagrams/sweep_tile_1.svg | 1 + .../ck_tile/diagrams/sweep_tile_2.svg | 1 + .../ck_tile/diagrams/sweep_tile_3.svg | 1 + .../ck_tile/diagrams/sweep_tile_4.svg | 1 + .../ck_tile/diagrams/tensor_coordinates_1.svg | 1 + .../ck_tile/diagrams/tensor_coordinates_2.svg | 1 + .../ck_tile/diagrams/tensor_views_1.svg | 1 + .../ck_tile/diagrams/tensor_views_2.svg | 1 + .../ck_tile/diagrams/tensor_views_3.svg | 1 + .../ck_tile/diagrams/tensor_views_4.svg | 1 + .../ck_tile/diagrams/tensor_views_5.svg | 1 + .../ck_tile/diagrams/thread_mapping_1.svg | 1 + .../ck_tile/diagrams/thread_mapping_2.svg | 1 + .../ck_tile/diagrams/tile_distribution_1.svg | 1 + .../ck_tile/diagrams/tile_distribution_2.svg | 1 + .../ck_tile/diagrams/tile_distribution_3.svg | 1 + .../ck_tile/diagrams/tile_distribution_4.svg | 1 + .../ck_tile/diagrams/tile_distribution_5.svg | 1 + .../ck_tile/diagrams/tile_distribution_6.svg | 1 + .../ck_tile/diagrams/tile_distribution_7.svg | 1 + .../ck_tile/diagrams/tile_window_1.svg | 1 + .../ck_tile/diagrams/tile_window_2.svg | 1 + .../ck_tile/diagrams/tile_window_3.svg | 1 + .../ck_tile/diagrams/tile_window_4.svg | 1 + .../ck_tile/diagrams/tile_window_5.svg | 1 + .../ck_tile/diagrams/transforms_1.svg | 1 + .../ck_tile/diagrams/transforms_10.svg | 1 + .../ck_tile/diagrams/transforms_11.svg | 1 + .../ck_tile/diagrams/transforms_12.svg | 1 + .../ck_tile/diagrams/transforms_2.svg | 1 + .../ck_tile/diagrams/transforms_3.svg | 1 + .../ck_tile/diagrams/transforms_4.svg | 1 + .../ck_tile/diagrams/transforms_5.svg | 1 + .../ck_tile/diagrams/transforms_6.svg | 1 + .../ck_tile/diagrams/transforms_7.svg | 1 + .../ck_tile/diagrams/transforms_8.svg | 1 + .../ck_tile/diagrams/transforms_9.svg | 1 + .../conceptual/ck_tile/encoding_internals.rst | 489 +++++++++++ .../ck_tile/hardware/gemm_optimization.rst | 385 +++++++++ .../ck_tile/hardware/gpu_basics.rst | 38 + 
docs/conceptual/ck_tile/hardware/index.rst | 127 +++ .../ck_tile/hardware/lds_bank_conflicts.rst | 209 +++++ docs/conceptual/ck_tile/index.rst | 108 +++ .../ck_tile/introduction_motivation.rst | 309 +++++++ .../conceptual/ck_tile/lds_index_swapping.rst | 462 +++++++++++ docs/conceptual/ck_tile/load_store_traits.rst | 480 +++++++++++ .../ck_tile/space_filling_curve.rst | 511 ++++++++++++ .../ck_tile/static_distributed_tensor.rst | 429 ++++++++++ docs/conceptual/ck_tile/sweep_tile.rst | 560 +++++++++++++ docs/conceptual/ck_tile/swizzling_example.rst | 495 +++++++++++ .../conceptual/ck_tile/tensor_coordinates.rst | 459 +++++++++++ docs/conceptual/ck_tile/tensor_views.rst | 482 +++++++++++ docs/conceptual/ck_tile/terminology.rst | 383 +++++++++ docs/conceptual/ck_tile/thread_mapping.rst | 551 +++++++++++++ docs/conceptual/ck_tile/tile_distribution.rst | 627 ++++++++++++++ docs/conceptual/ck_tile/tile_window.rst | 701 ++++++++++++++++ docs/conceptual/ck_tile/transforms.rst | 769 ++++++++++++++++++ docs/conceptual/ck_tile/update_diagrams.py | 215 +++++ docs/index.rst | 1 + docs/sphinx/_toc.yml.in | 2 + 98 files changed, 12671 insertions(+) create mode 100644 docs/conceptual/ck_tile/CK-tile-index.rst create mode 100644 docs/conceptual/ck_tile/MERMAID_DIAGRAMS.md create mode 100644 docs/conceptual/ck_tile/adaptors.rst create mode 100644 docs/conceptual/ck_tile/buffer_views.rst create mode 100644 docs/conceptual/ck_tile/cache_flushing_benchmarking.rst create mode 100644 docs/conceptual/ck_tile/convert_mermaid_to_svg.py create mode 100644 docs/conceptual/ck_tile/convert_raw_html_to_commented.py create mode 100644 docs/conceptual/ck_tile/convolution_example.rst create mode 100644 docs/conceptual/ck_tile/coordinate_movement.rst create mode 100644 docs/conceptual/ck_tile/coordinate_systems.rst create mode 100644 docs/conceptual/ck_tile/descriptors.rst create mode 100644 docs/conceptual/ck_tile/diagrams/adaptors_1.svg create mode 100644 
docs/conceptual/ck_tile/diagrams/adaptors_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/buffer_views_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/buffer_views_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/buffer_views_3.svg create mode 100644 docs/conceptual/ck_tile/diagrams/buffer_views_4.svg create mode 100644 docs/conceptual/ck_tile/diagrams/convolution_example.svg create mode 100644 docs/conceptual/ck_tile/diagrams/coordinate_movement.svg create mode 100644 docs/conceptual/ck_tile/diagrams/coordinate_systems_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/coordinate_systems_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/coordinate_systems_3.svg create mode 100644 docs/conceptual/ck_tile/diagrams/coordinate_systems_4.svg create mode 100644 docs/conceptual/ck_tile/diagrams/coordinate_systems_5.svg create mode 100644 docs/conceptual/ck_tile/diagrams/coordinate_systems_6.svg create mode 100644 docs/conceptual/ck_tile/diagrams/descriptors_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/descriptors_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/encoding_internals_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/encoding_internals_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/introduction_motivation_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/introduction_motivation_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/lds_index_swapping_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/lds_index_swapping_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/lds_index_swapping_3.svg create mode 100644 docs/conceptual/ck_tile/diagrams/load_store_traits_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/load_store_traits_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/space_filling_curve.svg create mode 100644 docs/conceptual/ck_tile/diagrams/static_distributed_tensor.svg create mode 100644 
docs/conceptual/ck_tile/diagrams/sweep_tile_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/sweep_tile_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/sweep_tile_3.svg create mode 100644 docs/conceptual/ck_tile/diagrams/sweep_tile_4.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tensor_coordinates_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tensor_coordinates_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tensor_views_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tensor_views_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tensor_views_3.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tensor_views_4.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tensor_views_5.svg create mode 100644 docs/conceptual/ck_tile/diagrams/thread_mapping_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/thread_mapping_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_distribution_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_distribution_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_distribution_3.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_distribution_4.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_distribution_5.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_distribution_6.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_distribution_7.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_window_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_window_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_window_3.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_window_4.svg create mode 100644 docs/conceptual/ck_tile/diagrams/tile_window_5.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_1.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_10.svg create mode 100644 
docs/conceptual/ck_tile/diagrams/transforms_11.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_12.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_2.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_3.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_4.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_5.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_6.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_7.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_8.svg create mode 100644 docs/conceptual/ck_tile/diagrams/transforms_9.svg create mode 100644 docs/conceptual/ck_tile/encoding_internals.rst create mode 100644 docs/conceptual/ck_tile/hardware/gemm_optimization.rst create mode 100644 docs/conceptual/ck_tile/hardware/gpu_basics.rst create mode 100644 docs/conceptual/ck_tile/hardware/index.rst create mode 100644 docs/conceptual/ck_tile/hardware/lds_bank_conflicts.rst create mode 100644 docs/conceptual/ck_tile/index.rst create mode 100644 docs/conceptual/ck_tile/introduction_motivation.rst create mode 100644 docs/conceptual/ck_tile/lds_index_swapping.rst create mode 100644 docs/conceptual/ck_tile/load_store_traits.rst create mode 100644 docs/conceptual/ck_tile/space_filling_curve.rst create mode 100644 docs/conceptual/ck_tile/static_distributed_tensor.rst create mode 100644 docs/conceptual/ck_tile/sweep_tile.rst create mode 100644 docs/conceptual/ck_tile/swizzling_example.rst create mode 100644 docs/conceptual/ck_tile/tensor_coordinates.rst create mode 100644 docs/conceptual/ck_tile/tensor_views.rst create mode 100644 docs/conceptual/ck_tile/terminology.rst create mode 100644 docs/conceptual/ck_tile/thread_mapping.rst create mode 100644 docs/conceptual/ck_tile/tile_distribution.rst create mode 100644 docs/conceptual/ck_tile/tile_window.rst create mode 100644 docs/conceptual/ck_tile/transforms.rst create mode 100644 
docs/conceptual/ck_tile/update_diagrams.py diff --git a/docs/conceptual/ck_tile/CK-tile-index.rst b/docs/conceptual/ck_tile/CK-tile-index.rst new file mode 100644 index 0000000000..e18cb24f80 --- /dev/null +++ b/docs/conceptual/ck_tile/CK-tile-index.rst @@ -0,0 +1,33 @@ +.. _ck_tile_index: + +************************ +CK Tile Index +************************ + +CK Tile documentation structure: + +.. toctree:: + :maxdepth: 2 + + introduction_motivation + buffer_views + tensor_views + tile_distribution + coordinate_systems + terminology + adaptors + transforms + descriptors + tile_window + load_store_traits + space_filling_curve + static_distributed_tensor + convolution_example + coordinate_movement + lds_index_swapping + swizzling_example + tensor_coordinates + sweep_tile + encoding_internals + thread_mapping + hardware/index diff --git a/docs/conceptual/ck_tile/MERMAID_DIAGRAMS.md b/docs/conceptual/ck_tile/MERMAID_DIAGRAMS.md new file mode 100644 index 0000000000..5e8679dbd2 --- /dev/null +++ b/docs/conceptual/ck_tile/MERMAID_DIAGRAMS.md @@ -0,0 +1,156 @@ +# Mermaid Diagram Management + +This document explains how to manage mermaid diagrams in the CK Tile documentation. + +## Overview + +All mermaid diagrams in the CK Tile documentation have been converted to SVG files for better rendering compatibility. The original mermaid source code is preserved as commented blocks in the RST files, allowing easy updates when needed. + +## Directory Structure + +- `docs/conceptual/ck_tile/diagrams/` - Contains all SVG diagram files +- `docs/conceptual/ck_tile/convert_mermaid_to_svg.py` - Initial conversion script (one-time use) +- `docs/conceptual/ck_tile/update_diagrams.py` - Helper script to regenerate diagrams from comments + +## Diagram Format in RST Files + +Each diagram follows this format: + +```rst +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + A --> B + B --> C + +.. 
image:: diagrams/diagram_name.svg + :alt: Diagram + :align: center +``` + +The commented mermaid block won't appear in the rendered documentation but serves as the source for regenerating the SVG. + +## Updating Diagrams + +### When to Update + +You need to regenerate SVG files when: +- Modifying the mermaid source in a commented block +- Adding new diagrams +- Updating diagram styling + +### How to Update + +1. **Edit the commented mermaid source** in the RST file +2. **Run the update script**: + ```bash + # Update all diagrams + python docs/conceptual/ck_tile/update_diagrams.py + + # Update diagrams in a specific file + python docs/conceptual/ck_tile/update_diagrams.py transforms.rst + + # Force regenerate all diagrams (even if SVGs exist) + python docs/conceptual/ck_tile/update_diagrams.py --force + ``` + +### Prerequisites + +The update script requires [mermaid-cli](https://github.com/mermaid-js/mermaid-cli): + +```bash +npm install -g @mermaid-js/mermaid-cli +``` + +## Adding New Diagrams + +To add a new mermaid diagram: + +1. **Create the commented block** in your RST file: + ```rst + .. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + A --> B + ``` + +2. **Add the image reference** immediately after: + ```rst + .. image:: diagrams/my_new_diagram.svg + :alt: My New Diagram + :align: center + ``` + +3. 
**Generate the SVG**: + ```bash + python docs/conceptual/ck_tile/update_diagrams.py your_file.rst + ``` + +## Current Diagrams + +The following RST files contain mermaid diagrams (40 total): + +- `adaptors.rst` (2 diagrams) +- `convolution_example.rst` (1 diagram) +- `coordinate_movement.rst` (1 diagram) +- `descriptors.rst` (2 diagrams) +- `encoding_internals.rst` (2 diagrams) +- `lds_index_swapping.rst` (3 diagrams) +- `load_store_traits.rst` (2 diagrams) +- `space_filling_curve.rst` (1 diagram) +- `static_distributed_tensor.rst` (1 diagram) +- `sweep_tile.rst` (4 diagrams) +- `tensor_coordinates.rst` (2 diagrams) +- `thread_mapping.rst` (2 diagrams) +- `tile_window.rst` (5 diagrams) +- `transforms.rst` (12 diagrams) + +## Troubleshooting + +### SVG not generated + +- Check that mermaid-cli is installed: `mmdc --version` +- Verify the mermaid syntax is valid +- Look for error messages in the script output + +### Diagram not updating + +- Use `--force` flag to regenerate: `python docs/conceptual/ck_tile/update_diagrams.py --force` +- Check that the image reference matches the generated filename + +### Pattern not matching + +If the update script can't find your commented diagram: +- Ensure proper indentation (3 spaces for comment block content) +- Verify the `.. mermaid::` directive is commented +- Check that the image reference immediately follows the comment block + +## Script Details + +### update_diagrams.py + +This script: +1. Scans RST files for commented mermaid blocks +2. Extracts the mermaid source code +3. Converts to SVG using `mmdc` +4. Saves to the diagrams directory + +**Usage:** +- `python docs/conceptual/ck_tile/update_diagrams.py` - Check all files, update missing SVGs +- `python docs/conceptual/ck_tile/update_diagrams.py --force` - Regenerate all SVGs +- `python docs/conceptual/ck_tile/update_diagrams.py <file.rst>` - Update specific file + +### convert_mermaid_to_svg.py + +This was the initial conversion script. It: +1. Found all active `.. mermaid::` directives +2. 
Converted them to SVGs +3. Replaced directives with commented source + image references + +This script was used once for the initial conversion and typically doesn't need to be run again. diff --git a/docs/conceptual/ck_tile/adaptors.rst b/docs/conceptual/ck_tile/adaptors.rst new file mode 100644 index 0000000000..9e8907ab10 --- /dev/null +++ b/docs/conceptual/ck_tile/adaptors.rst @@ -0,0 +1,391 @@ +.. _ck_tile_adaptors: + +Tensor Adaptors - Chaining Transformations +========================================== + +Overview +-------- + +While individual :ref:`transforms ` are effective, TensorAdaptors enable the chaining of multiple transforms together to create complex coordinate transformations. Adaptors can be thought of as transformation pipelines that can reshape, reorder, and restructure tensors in advanced ways. + +TensorAdaptors serve as the bridge between individual transforms and the high-level tensor operations used in applications. They provide a composable abstraction that allows developers to build complex data access patterns from simple building blocks. + +TensorAdaptor Basics +-------------------- + +A TensorAdaptor encapsulates a sequence of :ref:`coordinate transformations `, managing the flow of coordinates through multiple transform stages: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Adaptor Composition" + subgraph "Single Transform" + direction TB + I1["Input Coords
[0,1,2]"] + T1["Transform
(e.g., Transpose)"] + O1["Output Coords
[2,0,1]"] + I1 --> T1 --> O1 + end + + subgraph "Chained Transforms" + direction TB + I2["Input
2D"] + T2A["Transform A
(e.g., Merge)"] + M2["Intermediate
1D"] + T2B["Transform B
(e.g., Pad)"] + O2["Output
1D Padded"] + I2 --> T2A --> M2 --> T2B --> O2 + end + end + + style T1 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style T2A fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style T2B fill:#fff3e0,stroke:#f57c00,stroke-width:2px + + + + + +.. image:: diagrams/adaptors_1.svg + :alt: Diagram + :align: center + +.. image:: diagrams/adaptors_1.svg + :alt: Diagram + :align: center + +Core Components +~~~~~~~~~~~~~~~ + +Each TensorAdaptor contains: + +- **transforms**: List of individual :ref:`transforms <ck_tile_transforms>` to apply +- **lower_dimension_hidden_idss**: Mappings between transform stages +- **upper_dimension_hidden_idss**: Hidden dimension mappings for internal stages +- **bottom_dimension_hidden_ids**: Input dimension identifiers +- **top_dimension_hidden_ids**: Output dimension identifiers + +The most important method of a TensorAdaptor is ``calculate_bottom_index``, which calculates the lower index from the upper index by applying transforms in reverse order. + +Transpose Adaptor: Dimension Reordering +--------------------------------------- + +The transpose adaptor reorders tensor dimensions according to a permutation pattern. This operation forms the basis for many tensor manipulations in GPU kernels. + +.. 
code-block:: cpp + + // Create transpose adaptor: [0, 1, 2] → [2, 0, 1] + auto transpose_adaptor = make_identity_tensor_adaptor<3>(); // Start with identity + + // Apply transpose using transform_tensor_descriptor + auto transposed_desc = transform_tensor_descriptor( + original_desc, + make_tuple(make_pass_through_transform(original_desc.get_length(2)), + make_pass_through_transform(original_desc.get_length(0)), + make_pass_through_transform(original_desc.get_length(1))), + make_tuple(sequence<2>{}, sequence<0>{}, sequence<1>{}), // old dims + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}) // new dims + ); + + // Alternative: Direct coordinate transformation + multi_index<3> top_coord{0, 1, 2}; + // After transpose [2, 0, 1]: coord becomes [2, 0, 1] + +Single-Stage Adaptors: Custom Transform Chains +---------------------------------------------- + +Custom adaptors can be created by specifying which transforms to use and how they connect. This provides fine-grained control over the transformation pipeline: + +.. code-block:: cpp + + // Create a descriptor that merges 2x3 dimensions into single dimension + auto base_desc = make_naive_tensor_descriptor_packed(make_tuple(2, 3)); + + // Apply merge transform + auto merged_desc = transform_tensor_descriptor( + base_desc, + make_tuple(make_merge_transform(make_tuple(2, 3))), + make_tuple(sequence<0, 1>{}), // merge dims 0,1 + make_tuple(sequence<0>{}) // to single dim 0 + ); + + // The adaptor is embedded in the descriptor (see the ck_tile_descriptors section) + // To use it: + multi_index<1> top_coord{5}; // 1D coordinate + // This internally calculates: row = 5/3 = 1, col = 5%3 = 2 + +Chaining Adaptors: Building Complex Transformations +--------------------------------------------------- + +The real power of adaptors comes from chaining multiple transformations together to create advanced data access patterns: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. 
+ Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Adaptor Chaining Flow" + subgraph "Adaptor 1" + A1I["Bottom Dims
[0,1]"] + A1T["Transform:
Merge[2,3]"] + A1O["Top Dims
[0]"] + end + + subgraph "Adaptor 2" + A2I["Bottom Dims
[0]"] + A2T["Transform:
Unmerge[2,3]"] + A2O["Top Dims
[0,1]"] + end + + subgraph "Chained Result" + CI["Input 2D
Bottom[0,1]"] + CO["Output 2D
Top[0,1]"] + end + end + + A1I --> A1T + A1T --> A1O + A1O --> A2I + A2I --> A2T + A2T --> A2O + + CI --> A1I + A2O --> CO + + style A1T fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style A2T fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style CI fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style CO fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + + + +.. image:: diagrams/adaptors_2.svg + :alt: Diagram + :align: center + +.. image:: diagrams/adaptors_2.svg + :alt: Diagram + :align: center + +.. code-block:: cpp + + // Start with a 2D descriptor + auto desc1 = make_naive_tensor_descriptor_packed(make_tuple(2, 3)); + + // First transformation: merge 2D to 1D + auto merged_desc = transform_tensor_descriptor( + desc1, + make_tuple(make_merge_transform(make_tuple(2, 3))), + make_tuple(sequence<0, 1>{}), // merge dims 0,1 + make_tuple(sequence<0>{}) // to dim 0 + ); + + // Second transformation: unmerge 1D back to 2D + auto final_desc = transform_tensor_descriptor( + merged_desc, + make_tuple(make_unmerge_transform(make_tuple(2, 3))), + make_tuple(sequence<0>{}), // from dim 0 + make_tuple(sequence<0, 1>{}) // to dims 0,1 + ); + + // The chained transformation is embedded in final_desc + // Result should be identity transformation + +Transform Addition: Extending Existing Adaptors +----------------------------------------------- + +Existing adaptors can be extended with new transforms using ``transform_tensor_adaptor``. This pattern is useful for adding padding or other modifications to existing transformation pipelines: + +.. 
code-block:: cpp + + // Start with transposed descriptor + auto base_desc = make_naive_tensor_descriptor( + make_tuple(3, 4), + make_tuple(1, 3) // transposed strides + ); + + // Add padding to both dimensions + auto padded_desc = transform_tensor_descriptor( + base_desc, + make_tuple(make_pad_transform(3, 1, 1), // pad dim 0: 3 → 5 + make_pad_transform(4, 0, 0)), // keep dim 1: 4 → 4 + make_tuple(sequence<0>{}, sequence<1>{}), // input dims + make_tuple(sequence<0>{}, sequence<1>{}) // output dims (keep 2D) + ); + + // Access pattern + multi_index<2> padded_coord{1, 2}; // In padded space + // Internally calculates: unpadded = [1-1, 2] = [0, 2] + // Then applies transpose strides + +Advanced Patterns +----------------- + +Complex Nested Transforms +~~~~~~~~~~~~~~~~~~~~~~~~~ + +CK Tile supports complex nested transform patterns that enable advanced data layouts: + +.. code-block:: cpp + + // Example: 4D tensor with complex transformations + // Shape: [A, B, C, D] with various transforms + + // 1. Create base descriptor + auto base_desc = make_naive_tensor_descriptor_packed( + make_tuple(A, B, C, D) + ); + + // 2. Apply multiple transformations + // First: merge first 3 dimensions + auto step1_desc = transform_tensor_descriptor( + base_desc, + make_tuple(make_merge_transform(make_tuple(A, B, C)), + make_pass_through_transform(D)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), // input mapping + make_tuple(sequence<0>{}, sequence<1>{}) // output: 2D + ); + + // 3. 
Then unmerge back but with different grouping + auto step2_desc = transform_tensor_descriptor( + step1_desc, + make_tuple(make_unmerge_transform(make_tuple(A*B, C)), + make_pass_through_transform(D)), + make_tuple(sequence<0>{}, sequence<1>{}), // from 2D + make_tuple(sequence<0, 1>{}, sequence<2>{}) // to 3D + ); + + // The adaptor chain is embedded in the descriptors + // CK optimizes these at compile time + +GPU Memory Layout Example +~~~~~~~~~~~~~~~~~~~~~~~~~ + +A practical example showing how adaptors create efficient :ref:`GPU memory access patterns `: + +.. code-block:: cpp + + // Create descriptor for thread block tile: 64x64 + // With 8x8 vector loads per thread + constexpr auto BlockM = 64; + constexpr auto BlockN = 64; + constexpr auto VectorM = 8; + constexpr auto VectorN = 8; + + // Thread arrangement: 8x8 threads + constexpr auto ThreadM = BlockM / VectorM; // 8 + constexpr auto ThreadN = BlockN / VectorN; // 8 + + // Create block descriptor with proper layout + auto block_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed( + make_tuple(number<BlockM>{}, number<BlockN>{}) + ), + make_tuple( + make_unmerge_transform(make_tuple( + number<ThreadM>{}, number<VectorM>{} + )), + make_unmerge_transform(make_tuple( + number<ThreadN>{}, number<VectorN>{} + )) + ), + make_tuple(sequence<0>{}, sequence<1>{}), // from 2D + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}) // to 4D: [TM,TN,VM,VN] + ); + + // This creates the layout: + // - Dimension 0,1: Thread indices + // - Dimension 2,3: Vector indices within thread + // Enables coalesced memory access on GPU + // See :ref:`ck_tile_thread_mapping` for thread mapping details + +Common Transform Chains +----------------------- + +CK Tile provides several common transform chain patterns used throughout GPU kernels: + +**Padding for Convolution** + +.. 
code-block:: cpp + + auto padded = transform_tensor_descriptor( + input, + make_tuple(make_pad_transform(H, pad_h, pad_h), + make_pad_transform(W, pad_w, pad_w)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{}) + ); + +**Dimension Merging for GEMM** + +.. code-block:: cpp + + auto merged = transform_tensor_descriptor( + input, + make_tuple(make_merge_transform(make_tuple(M, K))), + make_tuple(sequence<0, 1>{}), + make_tuple(sequence<0>{}) + ); + +For complete GEMM optimization strategies, see :ref:`ck_tile_gemm_optimization`. + +**Broadcasting for Elementwise Operations** + +.. code-block:: cpp + + auto broadcast = transform_tensor_descriptor( + scalar, + make_tuple(make_replicate_transform(make_tuple(M, N))), + make_tuple(sequence<>{}), + make_tuple(sequence<0, 1>{}) + ); + +Key Concepts Summary +-------------------- + +TensorAdaptors are the coordination layer that makes complex tensor operations possible: + +- **Identity Adaptor**: Starting point for building transformations +- **Transpose Adaptor**: Dimension reordering with permutation patterns +- **Single-Stage Adaptors**: Custom transform chains with precise control +- **Chained Adaptors**: Complex multi-stage transformation pipelines +- **Transform Addition**: Extending existing adaptors with new transforms + +Core concepts to remember: + +- **Bottom/Top Dimensions**: Input and output coordinate spaces +- **Hidden Dimensions**: Internal coordinate mappings between transforms +- **Transform Chains**: Sequential application of multiple transforms +- **Coordinate Transformation**: Bidirectional mapping between coordinate spaces +- **Nested Transforms**: Complex multi-level transformation hierarchies + +Key C++ Patterns in Composable Kernel +-------------------------------------- + +1. **Descriptor-Based Adaptors**: In CK, adaptors are typically embedded within :ref:`tensor descriptors ` rather than created separately +2. 
**Compile-Time Optimization**: All transformations are resolved at compile time for zero overhead +3. **Type Safety**: Template metaprogramming ensures coordinate transformations are type-safe +4. **GPU Optimization**: Transform chains are designed for efficient GPU memory access patterns. See :ref:`ck_tile_lds_bank_conflicts` for LDS optimization. + +TensorAdaptors bridge the gap between low-level transforms and high-level tensor operations, providing the flexibility to create advanced data layouts and access patterns that are essential for efficient GPU computing. They build upon the foundation of :ref:`BufferViews ` and :ref:`TensorViews ` to provide complex transformation capabilities. + +Next Steps +---------- + +- :ref:`ck_tile_descriptors` - How adaptors combine with element space to form complete tensor descriptors +- :ref:`ck_tile_transforms` - Individual transform types and their properties +- :ref:`ck_tile_tile_window` - How adaptors enable efficient data loading patterns +- :ref:`ck_tile_space_filling_curve` - Advanced coordinate mapping techniques for cache optimization +- :ref:`ck_tile_static_distributed_tensor` - How adaptors help manage distributed tensor storage diff --git a/docs/conceptual/ck_tile/buffer_views.rst b/docs/conceptual/ck_tile/buffer_views.rst new file mode 100644 index 0000000000..14b8309504 --- /dev/null +++ b/docs/conceptual/ck_tile/buffer_views.rst @@ -0,0 +1,443 @@ +.. meta:: + :description: Composable Kernel CK Tile buffer views + :keywords: composable kernel, CK, CK Tile, ROCm, API, buffer view, raw memory + +.. _ck_tile_buffer_views: + +CK Tile buffer view +======================= + +Buffer view is an abstraction that provides structured access to memory. The ``buffer_view`` class is exposed in ``include/ck_tile/core/tensor/buffer_view.hpp``. + +Buffer view serves as the foundation for :ref:`ck_tile_tensor_views`. 
BufferView handles memory addressing and type safety, while TensorView builds upon this to add multi-dimensional coordinates (shape and strides). + + +Buffer view provides the following advantages: + +* A unified interface across global, shared, and register memory +* Address spaces encoded in types, taking advantage of compile-time type checking +* Configurable handling of invalid values, out-of-bounds operations, and conditional access patterns +* Atomic operations for parallel algorithms +* AMD GPU-specific optimizations +* Automatic application of appropriate memory ordering constraints and cache control directives based on the target address space and operation type + + +[TO DO: do we want to say more about these items? There wasn't a lot of detail in the original text, so I put them in a list for now] + + + +Address Space Usage Patterns +---------------------------- + +[TO DO: explain in words what the diagram shows] +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + flowchart TB + subgraph CF ["Compute Flow"] + direction LR + GM1["Global Memory
Input Data"] --> LDS["LDS
Tile Cache"] + LDS --> VGPR["VGPR
Working Set"] + VGPR --> Compute["Compute
Operations"] + Compute --> VGPR + VGPR --> LDS2["LDS
Reduction"] + LDS2 --> GM2["Global Memory
Output Data"] + end + + subgraph UP ["Usage Pattern"] + direction LR + P1["1. Load tile from Global → LDS"] + P2["2. Load working set LDS → VGPR"] + P3["3. Compute in VGPR"] + P4["4. Store results VGPR → LDS"] + P5["5. Reduce in LDS"] + P6["6. Write final LDS → Global"] + + P1 --> P2 --> P3 --> P4 --> P5 --> P6 + end + + CF ~~~ UP + + style GM1 fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style LDS fill:#fed7aa,stroke:#f59e0b,stroke-width:2px + style VGPR fill:#d1fae5,stroke:#10b981,stroke-width:2px + style Compute fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + + +.. image:: diagrams/buffer_views_1.svg + :alt: Diagram + :align: center + + +Basic Creation +~~~~~~~~~~~~~~ + +[TO DO: remove "modern C++ template metaprogramming" and "zero-overhead abstraction"] + +[TO DO: might want to move the implementation details to a separate section under "reference"] + + +.. code-block:: cpp + + #include + #include + + // Create buffer view in C++ + __device__ void example_buffer_creation() + { + // Static array in global memory + float data[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + constexpr index_t buffer_size = 8; + + // Create buffer view for global memory + // Template parameters: + auto buffer_view = make_buffer_view( + data, // pointer to data + buffer_size // number of elements + ); + + + // Implementation detail: The actual C++ template is: + // template + // struct buffer_view + + // Alternative: Create with explicit type + using buffer_t = buffer_view; + buffer_t explicit_buffer{data, number{}}; + + // Access properties at compile time + constexpr auto size = buffer_view.get_buffer_size(); + constexpr auto space = buffer_view.get_address_space(); + + // The buffer_view type encodes: + // - Data type (float) + // - Address space (global memory) + // - Size (known at compile time for optimization) + static_assert(size == 8, "Buffer size should be 8"); + static_assert(space == address_space_enum::global, "Should be global memory"); + } + +[TO DO: add 
details and remove unnecessary comments; the "implementation detail" comment can be moved out and either placed outside and explained further, or just removed, depending on what we want to do] + +[TO DO: might want to put this implementation detail in the reference section] + +Buffer view uses two modes, zero value mode and custom value mode, that can prevent serialization during bounds checking. + +Zero value mode returns zero without branching when an access falls outside the valid buffer range. This is useful in convolutions where out-of-bounds accesses correspond to zero-padding. + +Custom value mode returns a custom value without branching when an access falls outside the valid buffer range. Custom value mode accommodates algorithms that require specific values for boundary conditions. + +[TO DO: there were two examples of custom value mode that I removed. I removed them because unlike for zero value mode where the example was convolution, the example was vague in custom value. Is there a more specific example of where custom value would be used?] + +.. code-block:: cpp + + // Basic buffer view creation with automatic zero for invalid elements + void basic_creation_example() { + // Create data array + constexpr size_t buffer_size = 8; + float data[buffer_size] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + + // Create global memory buffer view + auto buffer_view = make_buffer_view(data, buffer_size); + } + + // Custom invalid value mode + void custom_invalid_value_example() { + constexpr size_t buffer_size = 8; + float data[buffer_size] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + float custom_invalid = 13.0f; + + // Create buffer view with custom invalid value + auto buffer_view = make_buffer_view( + data, buffer_size, custom_invalid); + } + + +When ``InvalidElementUseNumericalZeroValue`` is set to true, the system uses zero value mode for out of bounds checking. 
When ``InvalidElementUseNumericalZeroValue`` is set to false, custom value mode is used. Zero value mode is used by default. + +.. note:: + + The zero or custom invalid value is returned only when the entire access is invalid or out of bounds, for example when the first address of the vector is invalid. A partially out-of-bounds access during a vector read does not return useful results. + +.. code-block:: cpp + + // Create data array + constexpr size_t buffer_size = 8; + float data[buffer_size] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + float custom_invalid = 13.0f; + + // Create global memory buffer view with custom invalid value mode + auto buffer_view = make_buffer_view(data, buffer_size, custom_invalid); + + // Invalid element access with is_valid_element=false + // Returns custom_invalid due to custom invalid value mode + auto invalid_value = buffer_view.template get(0, 0, false); + printf("Invalid element: %.1f\n", invalid_value.get(0)); + + // Out of bounds access - AMD buffer addressing handles bounds checking + // Will return custom_invalid when accessing beyond buffer_size + auto oob_value = buffer_view.template get(0, 100, true); + printf("Out of bounds: %.1f\n", oob_value.get(0)); + + + + + +Get Operations +-------------- + +[TO DO: might want to put this implementation detail in the reference section] + +The ``buffer_view`` ``get()`` method takes four parameters: + +``i``: the primary offset into the buffer expressed in terms of elements of type T rather than raw bytes. + +``linear_offset``: [TO DO: what is this?] + +``is_valid_element``: [TO DO: what is this?] + +[TO DO: the last param, that's the out of bounds handling, yes?] +.. code:: cpp + + get(index_t i, + index_t linear_offset, + bool is_valid_element, + bool_constant = {}) + + +[TO DO: need some context around the code] + +[TO DO: code chunks need to have detail and explanation so that the reader can see what they're trying to demonstrate.] + + +.. 
code-block:: cpp + + // Create buffer view + float data[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + auto buffer_view = make_buffer_view(data, 8); + + // Simple get - compile-time bounds checking when possible + auto value_buf = buffer_view.template get(0,1,true); //get the buffer from the buffer view + float value = value_buf.get(0); //get the value from the buffer + + // Get with valid flag - branchless conditional access + bool valid_flag = false; + value_buf = buffer_view.template get(0,1,valid_flag); + value = value_buf.get(0); + // Returns 0 when valid_flag is false + + // vectorized get + using float2 = ext_vector_t; + auto vector_buf = buffer_view.template get(0, 0, true); + // Loads 2 floats in a single instruction + float val1 = vector_buf.get(0); + float val2 = vector_buf.get(1); + } + +``ext_vector_t`` enables compile-time selection of optimal load and store instructions that can transfer multiple data elements in a single memory transaction. + +[TO DO: what is it actually doing? When does one use scalars vs vectors? Is it application specific or are there ] + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Scalar Access (4 instructions)" + S1["Load float[0]"] --> R1["Register 1"] + S2["Load float[1]"] --> R2["Register 2"] + S3["Load float[2]"] --> R3["Register 3"] + S4["Load float[3]"] --> R4["Register 4"] + end + + subgraph "Vectorized Access (1 instruction)" + V1["Load float4[0]"] --> VR["Vector Register
(4 floats)"] + end + + subgraph "Performance Impact" + Perf["4x fewer instructions
Better memory bandwidth
Reduced latency"] + end + + R1 & R2 & R3 & R4 --> Perf + VR --> Perf + + style S1 fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style S2 fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style S3 fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style S4 fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style V1 fill:#d1fae5,stroke:#10b981,stroke-width:2px + style Perf fill:#fef3c7,stroke:#f59e0b,stroke-width:2px + + + + + + +.. image:: diagrams/buffer_views_2.svg + :alt: Diagram + :align: center + +Understanding BufferView Indexing +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +[TO DO: an explanation of the diagram is needed] + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + flowchart LR + subgraph "Input Parameters" + Offset["Offset
(e.g., 5)"] + ValidFlag["Valid Flag
(optional)"] + end + + subgraph "Processing" + BoundsCheck{{"Bounds Check
offset < buffer_size?"}} + FlagCheck{{"Flag Check
valid_flag == True?"}} + Access["Access Memory
buffer[offset]"] + end + + subgraph "Output" + ValidResult["Valid Result
Return value"] + Invalid["Invalid Result
Return 0 or default"] + end + + Offset --> BoundsCheck + ValidFlag --> FlagCheck + + BoundsCheck -->|Yes| FlagCheck + BoundsCheck -->|No| Invalid + + FlagCheck -->|Yes| Access + FlagCheck -->|No| Invalid + + Access --> ValidResult + + style Offset fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + style ValidFlag fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + style ValidResult fill:#d1fae5,stroke:#10b981,stroke-width:2px + style Invalid fill:#fee2e2,stroke:#ef4444,stroke-width:2px + + + + + + +.. image:: diagrams/buffer_views_3.svg + :alt: Diagram + :align: center + + + +Update Operations +----------------- + +Update operations modify the buffer content. The ``set()`` method writes a value to a specific location. + +.. code-block:: cpp + + void scalar_set_operations_example() { + + // Create data array + constexpr size_t buffer_size = 8; + float data[buffer_size] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + + // Create global memory buffer view + auto buffer_view = make_buffer_view(data, buffer_size); + + // Basic set: set(i, linear_offset, is_valid_element, value) + // Sets element at position i + linear_offset = 0 + 2 = 2 + buffer_view.template set(0, 2, true, 99.0f); + + // Invalid write with is_valid_element=false (ignored) + buffer_view.template set(0, 3, false, 777.0f); + + // Out of bounds write - handled safely by AMD buffer addressing + buffer_view.template set(0, 100, true, 555.0f); + + // Vector set + using float2 = ext_vector_t; + float2 pair_values{100.0f, 200.0f}; + buffer_view.template set(0, 5, true, pair_values); + } + +Atomic Operations +----------------- + +[TO DO: this needs information] + +Atomic vs Non-Atomic Operations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. 
mermaid:: + + graph TB + subgraph "Non-Atomic Operation (Race Condition)" + NA1["Thread 1: Read value (10)"] --> NA2["Thread 1: Add 5 (15)"] + NA3["Thread 2: Read value (10)"] --> NA4["Thread 2: Add 3 (13)"] + NA2 --> NA5["Thread 1: Write 15"] + NA4 --> NA6["Thread 2: Write 13"] + NA5 & NA6 --> NA7["Final value: 13 ❌
(Lost update from Thread 1)"] + end + + subgraph "Atomic Operation (Thread-Safe)" + A1["Thread 1: atomic_add(5)"] --> A2["Hardware ensures
serialization"] + A3["Thread 2: atomic_add(3)"] --> A2 + A2 --> A4["Final value: 18 ✓
(Both updates applied)"] + end + + style NA7 fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style A4 fill:#d1fae5,stroke:#10b981,stroke-width:2px + + + + + + +.. image:: diagrams/buffer_views_4.svg + :alt: Diagram + :align: center + +C++ Atomic Operations +~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + __device__ void example_atomic_operations() + { + // Shared memory for workgroup-level reductions + __shared__ float shared_sum[256]; + auto shared_buffer_view = make_buffer_view( + shared_sum, 256 + ); + + // Initialize shared memory: set(i, linear_offset, is_valid_element, value) + if (threadIdx.x < 256) { + shared_buffer_view.template set(threadIdx.x, 0, true, 0.0f); + } + __syncthreads(); + + // Each thread atomically adds to shared memory + auto my_value = static_cast(threadIdx.x); + shared_buffer_view.template update(0,0,true,my_value); + + // Atomic max for finding maximum value + shared_buffer_view.template update(0,1,true,my_value); + + __syncthreads(); + } diff --git a/docs/conceptual/ck_tile/cache_flushing_benchmarking.rst b/docs/conceptual/ck_tile/cache_flushing_benchmarking.rst new file mode 100644 index 0000000000..2866ba0c9f --- /dev/null +++ b/docs/conceptual/ck_tile/cache_flushing_benchmarking.rst @@ -0,0 +1,390 @@ +=================================== +Cache Flushing for GPU Benchmarking +=================================== + +Overview +======== + +When benchmarking GPU kernels, accurate performance measurements require understanding and controlling cache behavior. Running a kernel multiple times with the same input data can lead to artificially fast results due to **cache hits**, where data and instructions are served from fast GPU cache rather than slow High Bandwidth Memory (HBM). + +Composable Kernel provides two complementary mechanisms to ensure realistic "cold cache" performance measurements: + +1. **Instruction Cache Flushing** - Invalidates cached GPU instructions +2. 
**Rotating Memory Buffers** - Cycles through multiple data buffer copies at different memory addresses + +This document explains how these mechanisms work and how to use them in benchmarks. + +The Problem: Hot vs. Cold Cache +================================ + +GPU Memory Hierarchy +-------------------- + +GPUs have a multi-level cache hierarchy: + +.. code-block:: text + + Fast → Slow, Small → Large + + ┌─────────────────┐ + │ Register File │ ~1 cycle + ├─────────────────┤ + │ L1 I-Cache │ ~4 cycles ← Instruction cache + ├─────────────────┤ + │ L1 Data Cache │ ~4 cycles ← Data cache + ├─────────────────┤ + │ L2 Cache │ ~50 cycles + ├─────────────────┤ + │ HBM (VRAM) │ ~400 cycles + └─────────────────┘ + +Cache Behavior Without Flushing +-------------------------------- + +When running a kernel repeatedly without cache management: + +.. code-block:: text + + Run 1: [Cache MISS] → Fetch from HBM → 400 cycles → 5.2ms + Run 2: [Cache HIT!] → Read from L1/L2 → 4 cycles → 3.8ms ← Artificially fast! + Run 3: [Cache HIT!] → Read from L1/L2 → 4 cycles → 3.8ms + ... + Average: 4.1ms (misleading - not representative of real-world performance) + +This leads to: + +- ✗ Inflated performance numbers +- ✗ Inconsistent timing between first and subsequent runs +- ✗ Unfair comparisons between different kernels +- ✗ Misleading optimization decisions + +Solution 1: Instruction Cache Flushing +======================================= + +What is Instruction Cache? +--------------------------- + +The **instruction cache (I-cache)** is a small, fast memory on each GPU compute unit that stores recently executed machine code instructions. When a thread needs to execute an instruction: + +1. The **Program Counter (PC)** holds the instruction's memory address +2. The GPU checks if that address exists in the I-cache +3. **Cache HIT**: Instruction read instantly from I-cache (~4 cycles) +4. 
**Cache MISS**: Instruction fetched from HBM (~400 cycles), then cached + +How It Works +------------ + +The GPU uses **address-based caching**: when you launch the same kernel multiple times, the kernel code resides at the same memory address, allowing the I-cache to serve cached instructions. + +.. code-block:: text + + First Kernel Run: + PC = 0x7F8A0000 → I-Cache lookup → MISS → Fetch from HBM → Cache it + + Second Kernel Run (without flush): + PC = 0x7F8A0000 → I-Cache lookup → HIT! → Read from cache (fast!) + + Second Kernel Run (with flush): + PC = 0x7F8A0000 → I-Cache lookup → MISS → Fetch from HBM again + +The ``flush_icache()`` Function +-------------------------------- + +Located in ``include/ck_tile/host/flush_icache.hpp``: + +.. code-block:: cpp + + namespace ck_tile { + // GPU kernel to invalidate instruction cache for accurate benchmarking. + static __global__ void flush_cache() + { + asm __volatile__("s_icache_inv \n\t" // Invalidate I-cache + "s_nop 0 \n\t" // Wait cycles (16 NOPs) + "s_nop 0 \n\t" + // ... 14 more NOPs + "s_nop 0 \n\t" :: + :); + } + } + +**Key Components:** + +- ``s_icache_inv``: AMD GPU instruction that invalidates the L1 instruction cache on the current compute unit +- ``s_nop 0`` (×16): No-operation instructions (NOPs) that create a 16-cycle delay to ensure cache invalidation completes before the kernel exits + +**Why 16 NOPs?** + +The ``s_icache_inv`` instruction is **asynchronous**: it initiates cache invalidation but doesn't wait for completion. Without the NOPs, the kernel might exit before the flush finishes, leading to race conditions and incomplete cache invalidation. + +Launching the Flush Kernel +--------------------------- + +From ``include/ck_tile/host/rotating_buffers.hpp``: + +.. code-block:: cpp + + inline void flush_icache() + { + hipDeviceProp_t deviceProps; + HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0)); + + // Over-provision blocks to ensure all CUs execute the flush instruction. 
+ // With imperfect scheduling, launching exactly 1 block per CU doesn't guarantee coverage. + // 60x over-provisioning provides statistical certainty that every CU gets at least one block. + constexpr int32_t blocks_per_cu = 60; + int32_t gpu_block3 = deviceProps.multiProcessorCount * blocks_per_cu; + + ck_tile::flush_cache<<>>(); + HIP_CHECK_ERROR(hipGetLastError()); + } + +**Why 60× Over-provisioning?** + +The I-cache is **per-compute-unit** (CU). To flush all CUs, we must ensure every CU executes at least one instance of ``s_icache_inv``. + +- Launching exactly 1 block per CU doesn't guarantee 1:1 mapping due to GPU scheduler behavior +- Launching 60 blocks per CU provides statistical certainty that every CU receives work +- For a 120-CU GPU: 120 × 60 = 7,200 blocks × 64 threads = 460,800 total threads + +This ensures comprehensive instruction cache flushing across all compute units. + +Solution 2: Rotating Memory Buffers +==================================== + +What is Data Cache? +------------------- + +While I-cache stores instructions, **data cache** (L1 data, L2) stores matrix data (inputs A, B and output C). When a kernel reads the same matrix repeatedly, the data is served from cache rather than HBM. + +The RotatingMemWrapper Struct +------------------------------ + +Located in ``include/ck_tile/host/rotating_buffers.hpp``: + +.. code-block:: cpp + + template + struct RotatingMemWrapper + { + RotatingMemWrapper(const void* a_ptr_, + const void* b_ptr_, + std::size_t rotating_count_, + std::size_t size_a_, + std::size_t size_b_); + + void Next(); // Rotate to next buffer copy + ~RotatingMemWrapper() noexcept; // Cleanup + }; + +**Purpose**: Prevents data cache reuse by cycling through multiple copies of input matrices at different memory addresses. + +How It Works +------------ + +**Constructor: Create Buffer Copies** + +.. 
code-block:: cpp + + RotatingMemWrapper(a_ptr, b_ptr, rotating_count=3, size_a, size_b) + { + // Store original buffer pointers as first entry + p_a_grids.push_back(a_ptr); + p_b_grids.push_back(b_ptr); + + // Create (rotating_count - 1) additional copies at different memory addresses + for(size_t i = 1; i < rotating_count; i++) + { + void* pADeviceBuf; + hipMalloc(&pADeviceBuf, size_a); + hipMemcpy(pADeviceBuf, p_a_grids[0], size_a, hipMemcpyDeviceToDevice); + p_a_grids.push_back(pADeviceBuf); + + // Same for B matrix... + } + } + +Result: + +.. code-block:: text + + GPU Memory: + ┌─────────────────────────┐ + │ Matrix A (original) │ Address: 0x1000 + │ Matrix A (copy 1) │ Address: 0x2000 + │ Matrix A (copy 2) │ Address: 0x3000 + │ Matrix B (original) │ Address: 0x4000 + │ Matrix B (copy 1) │ Address: 0x5000 + │ Matrix B (copy 2) │ Address: 0x6000 + └─────────────────────────┘ + +**Next(): Rotate to Next Buffer** + +.. code-block:: cpp + + void Next() + { + if(rotating_count > 1) + { + std::size_t idx = iter++ % rotating_count; // Cycle: 0,1,2,0,1,2,... + a_ptr = p_a_grids[idx]; + b_ptr = p_b_grids[idx]; + } + } + +Usage in benchmarking loop: + +.. code-block:: text + + Iteration 1: Next() → Use buffers at 0x1000, 0x4000 → Kernel reads → Cache miss + Iteration 2: Next() → Use buffers at 0x2000, 0x5000 → Kernel reads → Cache miss + Iteration 3: Next() → Use buffers at 0x3000, 0x6000 → Kernel reads → Cache miss + Iteration 4: Next() → Use buffers at 0x1000, 0x4000 → Kernel reads → Cache miss + ... + +By the time the buffers cycle back to the first copy, the cache has likely evicted the old data. + +**Destructor: Cleanup** + +.. 
code-block:: cpp + + ~RotatingMemWrapper() noexcept + { + // Restore original buffer pointers + a_ptr = p_a_grids[0]; + b_ptr = p_b_grids[0]; + + // Free extra buffer copies (index 0 is original, don't free it) + for(size_t i = 1; i < rotating_count; i++) + { + hipFree(p_a_grids[i]); + hipFree(p_b_grids[i]); + } + } + +Using Cache Flushing in Practice +================================= + +Command Line Argument +--------------------- + +The ``flush_cache`` command-line argument controls whether cache flushing is enabled: + +.. code-block:: bash + + # Enable cache flushing (cold cache benchmarking) + ./gemm_example --flush_cache=1 --rotating_count=3 + + # Disable cache flushing (hot cache benchmarking) + ./gemm_example --flush_cache=0 + +In ``run_gemm_quant_example.inc``: + +.. code-block:: cpp + + bool flush_cache = arg_parser.get_bool("flush_cache"); + int rotating_count = arg_parser.get_int("rotating_count"); + + // Pass to stream_config + ck_tile::stream_config{ + nullptr, // stream + true, // time_kernel + 1, // log_level + n_warmup, // cold_niters (warmup iterations) + n_repeat, // nrepeat (timed iterations) + true, // is_gpu_timer + flush_cache, // flush_cache_ ← Controls cache flushing + rotating_count // rotating_count_ ← Number of buffer copies + } + +Integration with Timing Loop +----------------------------- + +The ``launch_kernel_time_mask`` function integrates both mechanisms: + +.. code-block:: cpp + + // From include/ck_tile/host/kernel_launch.hpp + template + float launch_kernel_time_mask(const stream_config& s, + PreprocessFunc preprocess, + Callables&&... callables) + { + // Timing loop (simplified) + for(int i = 0; i < s.nrepeat_; i++) + { + preprocess(); // 1. Flush I-cache + rotate buffers + callables_func(); // 2. Launch kernel + } + + return average_time; + } + +Complete Example +---------------- + +From ``example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc``: + +.. 
code-block:: cpp + + // Setup rotating memory wrapper + RotatingMemWrapper rotating_mem( + a_ptr, b_ptr, rotating_count, size_a, size_b); + + // Define preprocessing: flush I-cache + rotate buffers + auto preprocess = [&]() { + if(flush_cache) { + flush_icache(); // Invalidate instruction cache + rotating_mem.Next(); // Switch to next buffer copy + } + }; + + // Define kernel launch + auto kernel_launch = [&]() { + gemm_kernel<<>>(a_ptr, b_ptr, c_ptr, M, N, K); + }; + + // Benchmark with cache control + float avg_time = launch_kernel_time_mask( + stream_config, // Config with flush_cache and rotating_count + preprocess, // Flush + rotate before each iteration + kernel_launch // Kernel to benchmark + ); + +Execution Flow +-------------- + +With ``flush_cache=true``, ``rotating_count=3``, and ``nrepeat=100``: + +.. code-block:: text + + Warmup Phase (n_warmup iterations): + - Run kernel without timing + - Prime GPU, warm up scheduler + + Timed Phase (100 iterations): + Iteration 1: flush_icache() → rotating_mem.Next() → Use buffer copy 0 → kernel() → Measure + Iteration 2: flush_icache() → rotating_mem.Next() → Use buffer copy 1 → kernel() → Measure + Iteration 3: flush_icache() → rotating_mem.Next() → Use buffer copy 2 → kernel() → Measure + Iteration 4: flush_icache() → rotating_mem.Next() → Use buffer copy 0 → kernel() → Measure + ... + Iteration 100: flush_icache() → rotating_mem.Next() → Use buffer copy 0 → kernel() → Measure + + Return: Average time per iteration (excluding preprocess overhead) + +References +========== + +Related Files +------------- + +- ``include/ck_tile/host/flush_icache.hpp`` - I-cache flush kernel implementation +- ``include/ck_tile/host/rotating_buffers.hpp`` - RotatingMemWrapper implementation +- ``include/ck_tile/host/kernel_launch.hpp`` - Timing loop integration + +Conclusion +========== + +Accurate GPU kernel benchmarking requires careful control of cache behavior. 
The combination of **instruction cache flushing** (``flush_icache``) and **rotating memory buffers** (``RotatingMemWrapper``) ensures realistic "cold cache" performance measurements that represent real-world application behavior. + +By understanding and utilizing these mechanisms through the ``flush_cache`` command-line argument, you can obtain trustworthy performance data for optimization decisions and fair kernel comparisons. + diff --git a/docs/conceptual/ck_tile/convert_mermaid_to_svg.py b/docs/conceptual/ck_tile/convert_mermaid_to_svg.py new file mode 100644 index 0000000000..1d62405e53 --- /dev/null +++ b/docs/conceptual/ck_tile/convert_mermaid_to_svg.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 +""" +Script to convert all mermaid diagrams in CK Tile docs to SVGs. +This script: +1. Finds all mermaid blocks in RST files +2. Converts them to SVG using mmdc +3. Updates RST files to use SVG images with commented mermaid source +""" + +import os +import re +import subprocess +import tempfile +from pathlib import Path + +# Configuration +DOCS_DIR = Path(__file__).parent +DIAGRAMS_DIR = DOCS_DIR / "diagrams" +RST_FILES = [ + "convolution_example.rst", + "encoding_internals.rst", + "lds_index_swapping.rst", + "space_filling_curve.rst", + "sweep_tile.rst", + "tensor_coordinates.rst", + "thread_mapping.rst", + "static_distributed_tensor.rst", + "load_store_traits.rst", + "tile_window.rst", + "transforms.rst", + "descriptors.rst", + "coordinate_movement.rst", + "adaptors.rst", + "introduction_motivation.rst", + "buffer_views.rst", + "tensor_views.rst", + "coordinate_systems.rst", + "tile_distribution.rst", +] + +# Pattern to find mermaid blocks (can be indented with 3 spaces for commented blocks) +MERMAID_PATTERN = re.compile( + r"^(?: )?\.\. 
mermaid::\s*\n((?:(?:\n| .*))*)", re.MULTILINE +) + + +def extract_mermaid_content(block): + """Extract the actual mermaid code from the block, removing RST indentation.""" + lines = block.split("\n") + # Remove the leading spaces (RST indentation) + content_lines = [] + for line in lines: + if line.startswith(" "): + content_lines.append(line[3:]) # Remove 3 spaces + elif line.strip() == "": + content_lines.append("") + return "\n".join(content_lines).strip() + + +def generate_diagram_name(file_path, diagram_index, total_in_file): + """Generate a descriptive name for the diagram.""" + base_name = file_path.stem + if total_in_file == 1: + return f"{base_name}.svg" + else: + return f"{base_name}_{diagram_index + 1}.svg" + + +def convert_mermaid_to_svg(mermaid_code, output_path): + """Convert mermaid code to SVG using mmdc.""" + # Create a temporary file for the mermaid code + with tempfile.NamedTemporaryFile( + mode="w", suffix=".mmd", delete=False, encoding="utf-8" + ) as tmp: + tmp.write(mermaid_code) + tmp_path = tmp.name + + try: + # Run mmdc to convert to SVG (use shell=True on Windows for .cmd files) + subprocess.run( + [ + "mmdc", + "-i", + tmp_path, + "-o", + str(output_path), + "-t", + "neutral", + "-b", + "transparent", + ], + capture_output=True, + text=True, + check=True, + shell=True, # Required for Windows .cmd files + ) + print(f" ✓ Generated: {output_path.name}") + return True + except subprocess.CalledProcessError as e: + print(f" ✗ Error converting diagram: {e.stderr}") + return False + finally: + # Clean up temp file + os.unlink(tmp_path) + + +def update_rst_file(file_path, diagrams_info): + """Update RST file to replace mermaid blocks with commented source + image reference.""" + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + # Sort diagrams by position (reverse order to maintain positions) + diagrams_info.sort(key=lambda x: x["position"], reverse=True) + + for info in diagrams_info: + # Find the mermaid block + match 
= info["match"] + start_pos = match.start() + end_pos = match.end() + + # Create the replacement text + mermaid_block = match.group(0) + + # Create commented mermaid block + commented_lines = [ + ".. ", + " Original mermaid diagram (edit here, then run update_diagrams.py)", + " ", + ] + for line in mermaid_block.split("\n"): + commented_lines.append(f" {line}") + + # Add image reference + svg_rel_path = f"diagrams/{info['svg_name']}" + image_block = [ + "", + f".. image:: {svg_rel_path}", + " :alt: Diagram", + " :align: center", + "", + ] + + replacement = "\n".join(commented_lines + image_block) + + # Replace in content + content = content[:start_pos] + replacement + content[end_pos:] + + # Write back + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + + print(f" ✓ Updated: {file_path.name}") + + +def process_file(file_path): + """Process a single RST file.""" + print(f"\nProcessing {file_path.name}...") + + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + # Find all mermaid blocks + matches = list(MERMAID_PATTERN.finditer(content)) + + if not matches: + print(" No mermaid diagrams found.") + return + + print(f" Found {len(matches)} diagram(s)") + + diagrams_info = [] + + # Process each mermaid block + for idx, match in enumerate(matches): + mermaid_content = extract_mermaid_content(match.group(1)) + svg_name = generate_diagram_name(file_path, idx, len(matches)) + svg_path = DIAGRAMS_DIR / svg_name + + # Convert to SVG + if convert_mermaid_to_svg(mermaid_content, svg_path): + diagrams_info.append( + {"match": match, "svg_name": svg_name, "position": match.start()} + ) + + # Update the RST file + if diagrams_info: + update_rst_file(file_path, diagrams_info) + + +def main(): + """Main function.""" + print("CK Tile Mermaid to SVG Converter") + print("=" * 50) + + # Verify mmdc is available + try: + subprocess.run( + ["mmdc", "--version"], capture_output=True, check=True, shell=True + ) + except 
(subprocess.CalledProcessError, FileNotFoundError): + print("Error: mermaid-cli (mmdc) not found. Please install it:") + print(" npm install -g @mermaid-js/mermaid-cli") + return 1 + + # Ensure diagrams directory exists + DIAGRAMS_DIR.mkdir(parents=True, exist_ok=True) + + # Process each file + for rst_file in RST_FILES: + file_path = DOCS_DIR / rst_file + if file_path.exists(): + process_file(file_path) + else: + print(f"\n⚠ Warning: {rst_file} not found") + + print("\n" + "=" * 50) + print("✓ Conversion complete!") + print(f"SVG files saved to: {DIAGRAMS_DIR}") + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/docs/conceptual/ck_tile/convert_raw_html_to_commented.py b/docs/conceptual/ck_tile/convert_raw_html_to_commented.py new file mode 100644 index 0000000000..e90bf9def0 --- /dev/null +++ b/docs/conceptual/ck_tile/convert_raw_html_to_commented.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +"""Convert raw HTML mermaid blocks to commented format for SVG conversion.""" + +import os +import re + + +def convert_raw_html_to_commented(content): + """Convert raw HTML mermaid blocks to commented mermaid format.""" + + # Pattern to match raw HTML mermaid blocks + pattern = r'\.\. raw:: html\n\n
]*>\n(.*?)\n
' + + def replace_block(match): + mermaid_code = match.group(1) + # The mermaid code in HTML has 3-space indentation, keep it + # but add 3 more spaces for .. mermaid:: indentation + mermaid_lines = mermaid_code.split("\n") + properly_indented = [] + for line in mermaid_lines: + if line.strip(): # Non-empty line + # Line already has 3 spaces from HTML, add 3 more for mermaid block + properly_indented.append(" " + line) + else: + properly_indented.append("") + + indented_code = "\n".join(properly_indented) + + # Create commented format matching the expected pattern + commented = f""".. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + +{indented_code} + + +""" + return commented + + return re.sub(pattern, replace_block, content, flags=re.DOTALL) + + +def main(): + """Process files with raw HTML mermaid blocks.""" + + files_to_convert = [ + "introduction_motivation.rst", + "buffer_views.rst", + "tensor_views.rst", + "coordinate_systems.rst", + "tile_distribution.rst", + ] + + converted_files = [] + + for filename in files_to_convert: + if not os.path.exists(filename): + print(f"Skipping {filename} - not found") + continue + + with open(filename, "r", encoding="utf-8") as f: + original = f.read() + + converted = convert_raw_html_to_commented(original) + + if converted != original: + with open(filename, "w", encoding="utf-8") as f: + f.write(converted) + + blocks_converted = original.count(".. 
raw:: html") + converted_files.append((filename, blocks_converted)) + print(f"✓ Converted {filename}: {blocks_converted} blocks") + else: + print(f" {filename}: no raw HTML blocks found") + + print("\n=== CONVERSION COMPLETE ===") + print(f"Files converted: {len(converted_files)}") + print(f"Total blocks: {sum(c for _, c in converted_files)}") + print("\nNext: Run convert_mermaid_to_svg.py to generate SVG files") + + +if __name__ == "__main__": + main() diff --git a/docs/conceptual/ck_tile/convolution_example.rst b/docs/conceptual/ck_tile/convolution_example.rst new file mode 100644 index 0000000000..a981ae04da --- /dev/null +++ b/docs/conceptual/ck_tile/convolution_example.rst @@ -0,0 +1,567 @@ +.. meta:: + :description: CK Tile convolution implementation example + :keywords: CK Tile, convolution, im2col, tensor descriptors, GPU optimization + +.. _ck_tile_convolution_example: + +***************************************** +Convolution Implementation with CK Tile +***************************************** + +Overview +======== + +This section covers how CK Tile's :ref:`tensor descriptor ` system enables efficient convolution implementations on GPUs. Convolution operations are fundamental in deep learning, and understanding their optimization reveals how high-performance libraries achieve their efficiency. This section progresses from a naive implementation to an optimized approach using tensor descriptors, showing how they enable efficient memory access patterns for GPU acceleration. + +The key insight is that convolution can be transformed from a complex nested loop operation into a highly parallel matrix multiplication through the image to column (im2col) transformation. CK Tile's tensor descriptors provide the perfect abstraction for implementing this transformation efficiently without data duplication. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. 
mermaid:: + + graph TB + subgraph "Convolution Process" + I["Input Image
6×6"] + K["Kernel
3×3"] + SW["Sliding Window
Extract 3×3 patches"] + DP["Dot Product
Element-wise multiply & sum"] + O["Output
4×4"] + end + + subgraph "Im2col Optimization" + W["Windows Matrix
16×9
(all patches)"] + KF["Kernel Flattened
9×1"] + MM["Matrix Multiply
W @ K"] + OF["Output Flattened
16×1"] + end + + I --> SW + K --> DP + SW --> DP + DP --> O + + SW --> W + K --> KF + W --> MM + KF --> MM + MM --> OF + OF --> O + + style I fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style O fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style MM fill:#fff3e0,stroke:#f57c00,stroke-width:2px + + + + + +.. image:: diagrams/convolution_example.svg + :alt: Diagram + :align: center + +.. image:: diagrams/convolution_example.svg + :alt: Diagram + :align: center + +Understanding Sliding Windows +============================= + +Before diving into convolution, it's crucial to understand how sliding windows work. In convolution, overlapping patches need to be extracted from the input image. Traditional approaches would copy these patches, but CK Tile uses :ref:`tensor descriptors ` to create efficient :ref:`views ` without data duplication. + +Simple Tiling Example +--------------------- + +Non-overlapping tiles: + +.. code-block:: cpp + + // Create a 6x6 matrix tiled into 2x2 blocks + template + struct SimpleTiling { + static constexpr index_t kMatrixSize = 6; + static constexpr index_t kTileSize = 2; + static constexpr index_t kNumTiles = kMatrixSize / kTileSize; + + // Original matrix: shape=(6, 6), strides=(6, 1) + // Tiled view: shape=(3, 3, 2, 2), strides=(12, 2, 6, 1) + // See :ref:`ck_tile_descriptors` for descriptor details + using TileDescriptor = TensorDescriptor< + Sequence, + Sequence<12, 2, 6, 1> + >; + + __device__ void demonstrate() { + // To move to next tile row: skip 2 matrix rows = 6 × 2 = 12 + // To move to next tile col: skip 2 matrix cols = 1 × 2 = 2 + // Within tile: use original strides (6, 1) + } + }; + +The key insight is understanding **strides**, the number of elements to skip to move to the next element in each dimension. For non-overlapping tiles, we skip by ``tile_size`` in the outer dimensions. 
+ +Overlapping Windows for Convolution +------------------------------------ + +For convolution, overlapping windows that slide by one element are needed: + +.. code-block:: cpp + + // Extract 3x3 overlapping windows from a 6x6 image + template + struct ConvolutionWindows { + static constexpr index_t H = 6; // Image height + static constexpr index_t W = 6; // Image width + static constexpr index_t K = 3; // Kernel size + static constexpr index_t OutH = H - K + 1; // Output height = 4 + static constexpr index_t OutW = W - K + 1; // Output width = 4 + + // Windows descriptor: shape=(4, 4, 3, 3), strides=(6, 1, 6, 1) + using WindowDescriptor = TensorDescriptor< + Sequence, + Sequence // Key: stride by 1 for overlap! + >; + + __device__ DataType extract_window(const DataType* image, + index_t out_i, index_t out_j, + index_t k_i, index_t k_j) { + WindowDescriptor desc; + index_t offset = desc.calculate_offset({out_i, out_j, k_i, k_j}); + return image[offset]; + } + }; + +The stride pattern ``[W, 1, W, 1]`` creates sliding windows: + +- Moving one step in output row: jump ``W`` elements (one image row) +- Moving one step in output col: jump ``1`` element (one image column) +- Within each window: same strides to access the 3×3 patch + +Naive Convolution Implementation +================================ + +A straightforward implementation for reference: + +.. 
code-block:: cpp + + template + __global__ void naive_convolution_kernel( + const DataType* __restrict__ input, + const DataType* __restrict__ kernel, + DataType* __restrict__ output, + index_t H, index_t W, index_t K) + { + index_t out_h = H - K + 1; + index_t out_w = W - K + 1; + + // Each thread computes one output element + index_t out_i = blockIdx.y * blockDim.y + threadIdx.y; + index_t out_j = blockIdx.x * blockDim.x + threadIdx.x; + + if (out_i < out_h && out_j < out_w) { + DataType sum = 0; + + // Extract window and apply convolution + for (index_t ki = 0; ki < K; ++ki) { + for (index_t kj = 0; kj < K; ++kj) { + index_t in_i = out_i + ki; + index_t in_j = out_j + kj; + sum += input[in_i * W + in_j] * kernel[ki * K + kj]; + } + } + + output[out_i * out_w + out_j] = sum; + } + } + +This implementation directly follows the mathematical definition but has poor memory access patterns and limited parallelism within each output computation. + +Window Extraction with Tensor Descriptors +========================================= + +CK Tile's tensor descriptors provide a clean way to extract convolution windows: + +.. 
code-block:: cpp + + template + struct ConvolutionWindowExtractor { + static constexpr index_t OutH = H - K + 1; + static constexpr index_t OutW = W - K + 1; + + // Create tensor descriptor for all windows + using WindowsDescriptor = TensorDescriptor< + Sequence, + Sequence + >; + + __device__ void extract_all_windows( + const DataType* input, + DataType* windows_buffer) + { + WindowsDescriptor desc; + + // Extract all windows in parallel + index_t tid = threadIdx.x + blockIdx.x * blockDim.x; + index_t total_elements = OutH * OutW * K * K; + + for (index_t i = tid; i < total_elements; i += gridDim.x * blockDim.x) { + // Convert linear index to 4D coordinates + index_t tmp = i; + index_t kj = tmp % K; tmp /= K; + index_t ki = tmp % K; tmp /= K; + index_t out_j = tmp % OutW; tmp /= OutW; + index_t out_i = tmp; + + // Calculate source offset using descriptor + index_t src_offset = desc.calculate_offset({out_i, out_j, ki, kj}); + windows_buffer[i] = input[src_offset]; + } + } + }; + +The tensor descriptor automatically handles the complex indexing required for overlapping windows, making the code cleaner and less error-prone. + +Im2col Transformation +===================== + +The im2col transformation converts the 4D windows tensor into a 2D matrix suitable for matrix multiplication. This is where CK Tile's :ref:`transformation system ` shines: + +.. 
code-block:: cpp + + template + struct Im2colTransformer { + static constexpr index_t NumWindows = OutH * OutW; + static constexpr index_t PatchSize = K * K; + + // Step 1: Create 4D windows descriptor + using WindowsDescriptor = TensorDescriptor< + Sequence, + Sequence + >; + + // Step 2: Apply merge transforms to create 2D im2col layout + // See :ref:`ck_tile_transforms` for transform operations + using Im2colDescriptor = decltype( + transform_tensor_descriptor( + WindowsDescriptor{}, + make_tuple( + make_merge_transform(Sequence{}), // Merge spatial dims + make_merge_transform(Sequence{}) // Merge kernel dims + ), + Sequence<0, 1>{}, // Merge dimensions 0,1 + Sequence<2, 3>{} // Merge dimensions 2,3 + ) + ); + + __device__ void create_im2col_matrix( + const DataType* input, + DataType* im2col_matrix) + { + Im2colDescriptor desc; + + // Each thread handles multiple elements + index_t tid = threadIdx.x + blockIdx.x * blockDim.x; + index_t total_elements = NumWindows * PatchSize; + + for (index_t i = tid; i < total_elements; i += gridDim.x * blockDim.x) { + index_t window_idx = i / PatchSize; + index_t patch_idx = i % PatchSize; + + // Calculate source offset using merged descriptor + index_t src_offset = desc.calculate_offset({window_idx, patch_idx}); + im2col_matrix[i] = input[src_offset]; + } + } + }; + +The transformation pipeline: +1. Start with 4D tensor ``[OutH, OutW, K, K]`` +2. Merge spatial dimensions: ``[OutH, OutW] → NumWindows`` +3. Merge kernel dimensions: ``[K, K] → PatchSize`` +4. Result: 2D matrix ``[NumWindows, PatchSize]`` + +Optimized Convolution Kernel +============================ + +Combining all components into an optimized convolution implementation: + +.. 
code-block:: cpp + + template + __global__ void optimized_convolution_kernel( + const DataType* __restrict__ input, + const DataType* __restrict__ kernel, + DataType* __restrict__ output, + index_t H, index_t W, index_t K) + { + constexpr index_t WarpSize = 32; + const index_t OutH = H - K + 1; + const index_t OutW = W - K + 1; + const index_t NumWindows = OutH * OutW; + const index_t PatchSize = K * K; + + // Create im2col descriptor for this image size + using Im2colDesc = TensorDescriptor< + Sequence, + DynamicStrides // Computed based on H, W, K + >; + + // Tile distribution for matrix multiplication + // See :ref:`ck_tile_tile_distribution` for details + using ATileDist = TileDistribution< + Sequence, + Sequence + >; + using BTileDist = TileDistribution< + Sequence, + Sequence<1, BlockN> + >; + using CTileDist = TileDistribution< + Sequence, + Sequence + >; + + // Thread-local accumulator + // See :ref:`ck_tile_static_distributed_tensor` + StaticDistributedTensor c_accumulator; + + // Initialize accumulator + #pragma unroll + for (index_t i = 0; i < c_accumulator.size(); ++i) { + c_accumulator[i] = 0; + } + + // Main GEMM loop over K dimension + for (index_t k_tile = 0; k_tile < PatchSize; k_tile += TileK) { + // Create tile windows for im2col matrix and kernel + // See :ref:`ck_tile_tile_window` for window operations + auto a_window = make_tile_window( + input, Im2colDesc{H, W, K}, + {blockIdx.y * TileM, k_tile} + ); + + auto b_window = make_tile_window( + kernel, TensorDescriptor>{}, + {k_tile, 0} + ); + + // Load tiles - see :ref:`ck_tile_load_store_traits` for optimization + auto a_tile = a_window.load(); + auto b_tile = b_window.load(); + + // Synchronize after loads + __syncthreads(); + + // Local matrix multiplication + #pragma unroll + for (index_t m = 0; m < TileM/BlockM; ++m) { + #pragma unroll + for (index_t n = 0; n < TileN/BlockN; ++n) { + #pragma unroll + for (index_t k = 0; k < TileK; ++k) { + c_accumulator.at(m, n) += + a_tile.at(m, k) * 
b_tile.at(k, n); + } + } + } + } + + // Store results back to global memory + auto c_window = make_tile_window( + output, TensorDescriptor>{OutW, 1}, + {blockIdx.y * TileM, blockIdx.x * TileN} + ); + c_window.store(c_accumulator); + } + +Multi-Channel Convolution +========================= + +Real-world convolutions involve multiple input and output channels. CK Tile handles this cleanly: + +.. code-block:: cpp + + template + struct MultiChannelConvolution { + static constexpr index_t OutH = H - K + 1; + static constexpr index_t OutW = W - K + 1; + static constexpr index_t NumWindows = OutH * OutW; + static constexpr index_t PatchSize = K * K * CIn; + + // 5D windows descriptor [OutH, OutW, K, K, CIn] + using Windows5D = TensorDescriptor< + Sequence, + Sequence + >; + + // Im2col: [NumWindows, PatchSize] + using Im2colDesc = decltype( + transform_tensor_descriptor( + Windows5D{}, + make_tuple( + make_merge_transform(Sequence{}), + make_merge_transform(Sequence{}) + ), + Sequence<0, 1>{}, + Sequence<2, 3, 4>{} + ) + ); + + // Filter layout: [K*K*CIn, COut] + using FilterDesc = TensorDescriptor< + Sequence, + Sequence + >; + + __device__ void compute( + const DataType* input, // [H, W, CIn] + const DataType* filters, // [K, K, CIn, COut] + DataType* output) // [OutH, OutW, COut] + { + // The convolution becomes a matrix multiplication: + // [NumWindows, PatchSize] @ [PatchSize, COut] = [NumWindows, COut] + // Then reshape to [OutH, OutW, COut] + } + }; + +The multi-channel extension naturally follows from the single-channel case: + +- Input: ``[H, W, CIn]`` +- Filters: ``[K, K, CIn, COut]`` +- Im2col matrix: ``[NumWindows, K×K×CIn]`` +- Output: ``[OutH, OutW, COut]`` + +Performance Optimizations +========================= + +CK Tile enables several optimizations for convolution: + +**1. Memory Coalescing** + +.. 
code-block:: cpp + + // Coalesced access pattern for im2col + template + __device__ void load_im2col_vectorized( + const float* input, + float* im2col_tile, + const Im2colDescriptor& desc) + { + using VectorType = vector_type_t; + + // Load multiple elements per thread + index_t tid = threadIdx.x; + index_t stride = blockDim.x; + + for (index_t i = tid * VectorSize; i < NumElements; i += stride * VectorSize) { + VectorType vec = *reinterpret_cast(&input[i]); + *reinterpret_cast(&im2col_tile[i]) = vec; + } + } + +**2. Shared Memory Tiling** + +.. code-block:: cpp + + // Use shared memory for frequently accessed data + __shared__ float smem_a[TileM][TileK]; + __shared__ float smem_b[TileK][TileN]; + + // Collaborative loading with proper bank conflict avoidance + // See :ref:`ck_tile_lds_bank_conflicts` for optimization + auto load_tile_to_smem = [&](auto& window, float smem[][TileK]) { + #pragma unroll + for (index_t i = threadIdx.y; i < TileM; i += blockDim.y) { + #pragma unroll + for (index_t j = threadIdx.x; j < TileK; j += blockDim.x) { + smem[i][j] = window.at(i, j); + } + } + }; + +**3. Register Blocking** + +.. code-block:: cpp + + // Each thread computes multiple output elements + template + struct RegisterBlock { + float c_reg[RegM][RegN]; + + __device__ void compute(const float* a_smem, const float* b_smem) { + #pragma unroll + for (index_t k = 0; k < TileK; ++k) { + #pragma unroll + for (index_t m = 0; m < RegM; ++m) { + #pragma unroll + for (index_t n = 0; n < RegN; ++n) { + c_reg[m][n] += a_smem[m] * b_smem[n]; + } + } + } + } + }; + +Performance Characteristics +=========================== + +The tensor descriptor approach provides optimal performance characteristics: + +.. 
list-table:: Method Comparison + :header-rows: 1 + :widths: 25 20 20 20 15 + + * - Method + - Memory Usage + - Parallelization + - GPU Efficiency + - Flexibility + * - Naive loops + - Low + - Poor + - Poor + - High + * - Direct im2col copy + - High + - Excellent + - Good + - Medium + * - Tensor descriptors + - Medium + - Excellent + - Excellent + - High + * - CK Tile optimized + - Low + - Excellent + - Excellent + - High + +Key advantages of the CK Tile approach: + +1. **Zero-copy views**: Tensor descriptors create logical views without data duplication +2. **Compile-time optimization**: All indexing calculations resolve at compile time +3. **Hardware-aware**: Automatic alignment and vectorization based on :ref:`architecture ` +4. **Composability**: Complex access patterns built from simple :ref:`transformations ` +5. **Performance portability**: Same code optimizes differently for different GPUs + +Summary +======= + +This example demonstrates how CK Tile transforms convolution from a memory-bound operation with poor parallelism into a compute-bound operation that utilizes GPU resources. The key insights are: + +- **Sliding windows** can be efficiently represented using tensor descriptors with appropriate strides +- **Im2col transformation** converts convolution to matrix multiplication without data copies +- **Tile distribution** enables optimal work distribution across GPU threads (see :ref:`ck_tile_tile_distribution`) +- **Multi-channel support** extends naturally through higher-dimensional descriptors +- **Performance optimizations** like vectorization and shared memory are seamlessly integrated (see :ref:`ck_tile_gemm_optimization` for similar techniques) + +The tensor descriptor system provides a unified framework for these transformations, enabling automatic generation of efficient kernels for various convolution configurations and hardware architectures. 
This approach forms the foundation for production deep learning frameworks' convolution implementations. diff --git a/docs/conceptual/ck_tile/coordinate_movement.rst b/docs/conceptual/ck_tile/coordinate_movement.rst new file mode 100644 index 0000000000..73633afa88 --- /dev/null +++ b/docs/conceptual/ck_tile/coordinate_movement.rst @@ -0,0 +1,532 @@ +.. meta:: + :description: CK Tile advanced coordinate operations documentation + :keywords: CK Tile, coordinate movement, tensor coordinates, GPU programming + +.. _ck_tile_coordinate_movement: + +**************************** +Advanced Coordinate Movement +**************************** + +Overview +======== + +Advanced coordinate operations form the bridge between mathematical transformations and practical tensor manipulation in CK Tile. These operations enable efficient navigation through complex tensor layouts without recalculating entire transformation chains. Understanding coordinate movement is essential for implementing high-performance GPU kernels that traverse multi-dimensional data structures. + +The coordinate movement system provides two key abstractions: TensorCoordinate for descriptor-aware navigation and TensorAdaptorCoordinate for tracking positions through transformation chains. Together with movement functions, they enable advanced access patterns while maintaining optimal performance through incremental updates rather than full recalculation. + +For the mathematical foundations of coordinate systems, see :ref:`ck_tile_coordinate_systems`. For simpler coordinate concepts, see :ref:`ck_tile_tensor_coordinates`. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Coordinate Movement System" + TC["TensorCoordinate
Position + Descriptor Context"] + TAC["TensorAdaptorCoordinate
Position + Transform Context"] + MC["move_coordinate()
Efficient Navigation"] + end + + subgraph "Movement Example" + S["Start: [1,1]
Offset: 5"] + M1["Move [0,1]
→ [1,2]
Offset: 6"] + M2["Move [1,0]
→ [2,2]
Offset: 10"] + M3["Move [1,1]
→ [3,3]
Offset: 15"] + end + + TC --> MC + TAC --> MC + + S --> M1 + M1 --> M2 + M2 --> M3 + + style TC fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style TAC fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style MC fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + + + + +.. image:: diagrams/coordinate_movement.svg + :alt: Diagram + :align: center + +.. image:: diagrams/coordinate_movement.svg + :alt: Diagram + :align: center + +TensorCoordinate: Descriptor-Aware Navigation +============================================= + +TensorCoordinate combines a multi-dimensional position with descriptor context to provide efficient offset calculation and validation. It caches transformation results to avoid redundant computations during navigation. This builds on the :ref:`ck_tile_descriptors` concepts for tensor specifications. + +Basic Structure +--------------- + +.. code-block:: cpp + + template + class TensorCoordinate { + private: + MultiIndex top_index_; // Position in top dimensions + MultiIndex hidden_index_; // Cached transformation results + index_t offset_; // Cached linear offset + + public: + // Create coordinate from descriptor and position + __host__ __device__ TensorCoordinate( + const TensorDescriptor& desc, + const MultiIndex& top_index) + { + top_index_ = top_index; + // Apply descriptor transforms to compute hidden indices + hidden_index_ = desc.calculate_bottom_index(top_index); + offset_ = desc.calculate_offset(top_index); + } + + // Access methods + __host__ __device__ const MultiIndex& get_index() const { + return top_index_; + } + + __host__ __device__ index_t get_offset() const { + return offset_; + } + + __host__ __device__ index_t ndim_hidden() const { + return hidden_index_.size(); + } + }; + +Creating and Using TensorCoordinate +----------------------------------- + +.. 
code-block:: cpp + + // Example: Navigate a 4x3 matrix with custom strides + template + __device__ void demonstrate_tensor_coordinate() { + // Create descriptor for 4x3 matrix, row-major layout + using Desc = TensorDescriptor< + Sequence<4, 3>, // Shape + Sequence<3, 1> // Strides + >; + Desc desc; + + // Create coordinate at position [2, 1] + auto coord = make_tensor_coordinate(desc, make_multi_index(2, 1)); + + // Access coordinate information + auto position = coord.get_index(); // [2, 1] + auto offset = coord.get_offset(); // 2*3 + 1 = 7 + auto hidden_dims = coord.ndim_hidden(); // 0 (no hidden dims) + + // Use offset for memory access + DataType* tensor_data = ...; + DataType value = tensor_data[offset]; + } + +Key Benefits +------------ + +1. **Context Preservation**: The coordinate maintains descriptor context for validation +2. **Cached Calculations**: Transformation results are cached for efficiency +3. **Type Safety**: Compile-time checking ensures coordinate-descriptor compatibility +4. **Zero Overhead**: All operations resolve at compile time when possible + + +TensorAdaptorCoordinate: Transform-Aware Tracking +================================================== + +TensorAdaptorCoordinate extends the concept to track coordinates through transformation chains, maintaining both input (top) and output (bottom) positions. This leverages :ref:`ck_tile_adaptors` and :ref:`ck_tile_transforms` for complex coordinate mappings. + +Structure and Implementation +---------------------------- + +.. 
code-block:: cpp + + template + class TensorAdaptorCoordinate { + private: + MultiIndex top_index_; // Input position + MultiIndex bottom_index_; // Output after transformations + MultiIndex hidden_index_; // Intermediate results + + public: + // Create from adaptor and position + __host__ __device__ TensorAdaptorCoordinate( + const TensorAdaptor& adaptor, + const MultiIndex& top_index) + { + top_index_ = top_index; + // Apply adaptor transforms + bottom_index_ = adaptor.calculate_bottom_index(top_index); + // Cache intermediate results + hidden_index_ = adaptor.get_hidden_index(top_index); + } + + // Access transformed coordinates + __host__ __device__ const MultiIndex& get_top_index() const { + return top_index_; + } + + __host__ __device__ const MultiIndex& get_bottom_index() const { + return bottom_index_; + } + }; + +Tracking Through Transformations +-------------------------------- + +.. code-block:: cpp + + // Example: Track coordinates through transpose + template + __device__ void demonstrate_adaptor_coordinate() { + // Create transpose adaptor (swap dimensions) + auto adaptor = make_transpose_adaptor<2>(Sequence<1, 0>{}); + + // Create coordinate at [2, 3] + auto coord = make_tensor_adaptor_coordinate( + adaptor, + make_multi_index(2, 3) + ); + + // Track transformation + auto input_pos = coord.get_top_index(); // [2, 3] + auto output_pos = coord.get_bottom_index(); // [3, 2] (swapped) + + // Use for complex access patterns + DataType* src_data = ...; + DataType* dst_data = ...; + + // Read from transposed position + index_t src_offset = calculate_offset(output_pos); + DataType value = src_data[src_offset]; + } + +Efficient Coordinate Movement +============================= + +The ``move_tensor_coordinate`` function provides efficient navigation by updating coordinates incrementally rather than recreating them. + +Basic Movement Operations +------------------------- + +.. 
code-block:: cpp + + // Move tensor coordinate through descriptor + template + __host__ __device__ void move_tensor_coordinate( + const TensorDescriptor& desc, + TensorCoordinate& coord, + const MultiIndex& step) + { + // Update top index + coord.top_index_ += step; + + // Incrementally update cached values + // Only recalculate affected transformations + if (transformation_affects_movement(desc, step)) { + coord.hidden_index_ = desc.calculate_bottom_index(coord.top_index_); + coord.offset_ = desc.calculate_offset(coord.top_index_); + } else { + // Fast path: simple offset update + coord.offset_ += calculate_step_offset(desc, step); + } + } + +Practical Movement Patterns +--------------------------- + +.. code-block:: cpp + + // Example: Efficient matrix traversal + template + __global__ void matrix_traversal_kernel( + const DataType* input, + DataType* output, + index_t rows, index_t cols) + { + // Create descriptor for matrix + using Desc = TensorDescriptor; + Desc desc(make_tuple(rows, cols), make_tuple(cols, 1)); + + // Start at thread's assigned position + index_t start_row = blockIdx.y * blockDim.y + threadIdx.y; + index_t start_col = blockIdx.x * blockDim.x + threadIdx.x; + + auto coord = make_tensor_coordinate( + desc, + make_multi_index(start_row, start_col) + ); + + // Row-wise traversal pattern + for (index_t i = 0; i < 4; ++i) { + if (coord.get_index()[0] < rows) { + // Process current position + output[coord.get_offset()] = + process_value(input[coord.get_offset()]); + + // Move to next column + move_tensor_coordinate(desc, coord, make_multi_index(0, 1)); + + // Wrap to next row if needed + if (coord.get_index()[1] >= cols) { + move_tensor_coordinate( + desc, coord, + make_multi_index(1, -cols) + ); + } + } + } + } + +Movement Through Adaptors +------------------------- + +.. 
code-block:: cpp + + // Move through adaptor transformations + template + __host__ __device__ MultiIndex move_tensor_adaptor_coordinate( + const TensorAdaptor& adaptor, + TensorAdaptorCoordinate& coord, + const MultiIndex& step) + { + // Update top index + MultiIndex old_top = coord.top_index_; + coord.top_index_ += step; + + // Calculate new bottom index + MultiIndex old_bottom = coord.bottom_index_; + coord.bottom_index_ = adaptor.calculate_bottom_index(coord.top_index_); + + // Return the change in bottom coordinates + return coord.bottom_index_ - old_bottom; + } + +Advanced Movement Patterns +========================== + +Real-world applications use advanced movement patterns for optimal memory access. These patterns often relate to :ref:`ck_tile_tile_window` operations and :ref:`ck_tile_tile_distribution` concepts: + +Tiled Access Pattern +-------------------- + +.. code-block:: cpp + + template + __device__ void tiled_movement_pattern( + const float* input, + float* output, + index_t M, index_t N) + { + // Descriptor for full matrix + using MatrixDesc = TensorDescriptor< + DynamicSequence, + DynamicSequence + >; + MatrixDesc desc(make_tuple(M, N), make_tuple(N, 1)); + + // Start at tile corner + index_t tile_row = blockIdx.y * TileM; + index_t tile_col = blockIdx.x * TileN; + + auto coord = make_tensor_coordinate( + desc, + make_multi_index(tile_row, tile_col) + ); + + // Process tile with efficient movement + #pragma unroll + for (index_t i = 0; i < TileM; ++i) { + #pragma unroll + for (index_t j = 0; j < TileN; ++j) { + if (i == 0 && j == 0) { + // First element - already positioned + } else if (j == 0) { + // New row - move down and back to start column + move_tensor_coordinate( + desc, coord, + make_multi_index(1, -(TileN-1)) + ); + } else { + // Same row - move right + move_tensor_coordinate( + desc, coord, + make_multi_index(0, 1) + ); + } + + // Process element + output[coord.get_offset()] = + compute_value(input[coord.get_offset()]); + } + } + } + 
+Space-Filling Curve Movement +---------------------------- + +For more details on space-filling curves and their benefits, see :ref:`ck_tile_space_filling_curve`. + +.. code-block:: cpp + + // Snake pattern for optimal cache usage + template + __device__ void snake_pattern_movement( + const float* input, + float* output, + index_t M, index_t N) + { + using Desc = TensorDescriptor; + Desc desc(make_tuple(M, N), make_tuple(N, 1)); + + auto coord = make_tensor_coordinate( + desc, + make_multi_index(threadIdx.y, threadIdx.x) + ); + + // Snake through block + for (index_t row = 0; row < BlockSize; ++row) { + for (index_t col = 0; col < BlockSize; ++col) { + // Process current position + process_element(input, output, coord.get_offset()); + + // Snake movement pattern + if (row % 2 == 0) { + // Even rows: move right + if (col < BlockSize - 1) { + move_tensor_coordinate( + desc, coord, make_multi_index(0, 1) + ); + } + } else { + // Odd rows: move left + if (col < BlockSize - 1) { + move_tensor_coordinate( + desc, coord, make_multi_index(0, -1) + ); + } + } + } + + // Move to next row + if (row < BlockSize - 1) { + move_tensor_coordinate( + desc, coord, make_multi_index(1, 0) + ); + } + } + } + +Performance Considerations +=================================== + +Efficient coordinate movement is critical for GPU performance. See :ref:`ck_tile_gpu_basics` for hardware details. + +**1. Incremental Updates** + +.. code-block:: cpp + + // Inefficient: recreate coordinate + for (index_t i = 0; i < N; ++i) { + auto coord = make_tensor_coordinate(desc, make_multi_index(i, j)); + process(data[coord.get_offset()]); + } + + // Efficient: incremental movement + auto coord = make_tensor_coordinate(desc, make_multi_index(0, j)); + for (index_t i = 0; i < N; ++i) { + process(data[coord.get_offset()]); + move_tensor_coordinate(desc, coord, make_multi_index(1, 0)); + } + +**2. Movement Caching** + +.. 
code-block:: cpp + + // Cache frequently used movements + template + struct MovementCache { + MultiIndex row_step = make_multi_index(1, 0); + MultiIndex col_step = make_multi_index(0, 1); + MultiIndex diag_step = make_multi_index(1, 1); + + __device__ void move_row(auto& coord) { + move_tensor_coordinate(Desc{}, coord, row_step); + } + }; + +**3. Vectorized Movement** + +.. code-block:: cpp + + // Move multiple coordinates simultaneously + template + __device__ void vectorized_movement( + TensorCoordinate coords[NumCoords], + const MultiIndex& step) + { + #pragma unroll + for (index_t i = 0; i < NumCoords; ++i) { + move_tensor_coordinate(Desc{}, coords[i], step); + } + } + +Integration with CK Tile Components +=================================== + +Coordinate movement integrates seamlessly with other CK Tile components: + +.. code-block:: cpp + + // Example: Tile window with coordinate movement + template + __device__ void process_tile_with_movement( + TileWindow& window, + index_t tile_size) + { + // Create coordinate for tile traversal + auto coord = window.get_tile_coordinate(); + + // Process tile elements with movement + for (index_t i = 0; i < tile_size; ++i) { + for (index_t j = 0; j < tile_size; ++j) { + // Load using coordinate + auto value = window.load_at(coord); + + // Process value + auto result = compute(value); + + // Store result + window.store_at(coord, result); + + // Move to next element + window.move_coordinate(coord, {0, 1}); + } + // Move to next row + window.move_coordinate(coord, {1, -tile_size}); + } + } + + +Advanced coordinate operations provide the foundation for efficient tensor navigation in CK Tile: + +- **TensorCoordinate**: Combines position with descriptor context for validated navigation +- **TensorAdaptorCoordinate**: Tracks coordinates through transformation chains +- **move_tensor_coordinate**: Enables efficient incremental updates without recalculation +- **Movement Patterns**: Support advanced access patterns like tiling and 
space-filling curves +- **Performance**: Incremental updates are orders of magnitude faster than coordinate recreation +- **Integration**: Seamlessly works with tile windows, distributions, and other CK Tile components + +These operations are essential for implementing high-performance GPU kernels that can navigate complex tensor layouts efficiently. By understanding and utilizing coordinate movement, kernels can be created that achieve optimal memory access patterns while maintaining code clarity and correctness. diff --git a/docs/conceptual/ck_tile/coordinate_systems.rst b/docs/conceptual/ck_tile/coordinate_systems.rst new file mode 100644 index 0000000000..13a9619010 --- /dev/null +++ b/docs/conceptual/ck_tile/coordinate_systems.rst @@ -0,0 +1,612 @@ +.. _ck_tile_coordinate_systems: + +Coordinate Systems - The Mathematical Foundation +================================================ + +Overview +-------- + +At the heart of the Composable Kernel framework lies a mathematical foundation based on coordinate transformations. This foundation enables the automatic generation of optimal memory access patterns while maintaining a clear separation between algorithmic intent and hardware implementation details. The coordinate system framework transforms the task of GPU work distribution into a series of well-defined mathematical transformations. + +These coordinate systems provide the mathematical machinery that maps abstract thread identities to concrete memory addresses, ensuring that every memory access is optimized for the underlying hardware. This systematic approach eliminates the error-prone manual calculations that plague traditional GPU programming while enabling optimizations that would be impractical to implement by hand. + +The Five Coordinate Spaces +-------------------------- + +The CK framework employs five interconnected coordinate spaces, each serving a specific purpose in the journey from thread identification to memory access. 
These spaces work together to solve the fundamental challenge of GPU programming: efficiently distributing work across thousands of parallel threads while maintaining optimal memory access patterns. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Coordinate Spaces Overview" + P["P-space
Thread Identification
Which thread am I?"] + Y["Y-space
Logical Tile
Which element in my tile?"] + X["X-space
Physical Tensor
Where in the tensor?"] + R["R-space
Replication
Data sharing pattern"] + D["D-space
Linear Storage
Memory address"] + end + + subgraph "Transformations" + T1["P + Y → X
Thread + Element → Position"] + T2["X → D
Position → Address"] + end + + P --> T1 + Y --> T1 + T1 --> X + X --> T2 + T2 --> D + + R -.-> P + R -.-> Y + + style P fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style Y fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style X fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style R fill:#fce4ec,stroke:#c2185b,stroke-width:2px + style D fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + + + + + + +.. image:: diagrams/coordinate_systems_1.svg + :alt: Diagram + :align: center + +The Challenge and Solution +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Consider a fundamental scenario: an 8×8 matrix and 4 GPU threads. Each thread needs to answer several critical questions: + +1. **Which thread am I?** (Thread identification) +2. **What work should I do?** (Work assignment) +3. **Where is my data in the tensor?** (Physical location) +4. **How do I share data with other threads?** (Cooperation) +5. **What's the memory address?** (Hardware access) + +The coordinate system framework provides a systematic solution through five specialized spaces that transform from logical concepts to physical reality. Each space captures a different aspect of the computation, and the transformations between them encode the distribution strategy. + +Thread Identification +------------------------------ + +Partition Space (P-space) represents the foundation of the coordinate system hierarchy. This space captures the identity of each processing element within the GPU's execution model, providing a structured way to identify threads across the complex hierarchy of warps, blocks, and grids. + +GPU Thread Hierarchy +~~~~~~~~~~~~~~~~~~~~ + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "GPU Thread Hierarchy" + subgraph "Block" + subgraph "Warp 0" + T0["Thread 0
P=[0,0]"] + T1["Thread 1
P=[0,1]"] + T2["Thread 2
P=[0,2]"] + T31["..."] + T3["Thread 31
P=[0,31]"] + end + subgraph "Warp 1" + T32["Thread 32
P=[1,0]"] + T33["Thread 33
P=[1,1]"] + T34["..."] + T63["Thread 63
P=[1,31]"] + end + W2["Warp 2..."] + W7["Warp 7"] + end + end + + subgraph "P-space Mapping" + PM["P-coordinates = [warp_id, lane_id]
or
P-coordinates = [block_x, block_y, thread_x, thread_y]"] + end + + T0 --> PM + T32 --> PM + + style T0 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style T32 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + + + + + + +.. image:: diagrams/coordinate_systems_2.svg + :alt: Diagram + :align: center + +The structure of P-space directly reflects the :ref:`hardware organization ` of GPUs. Each thread receives a unique P-coordinate that encodes its position within the execution hierarchy. For simple distributions, P-space might be one-dimensional, containing only a thread ID. For complex hierarchical distributions, P-space can have multiple dimensions representing different levels of the GPU's thread organization. + +C++ Implementation +~~~~~~~~~~~~~~~~~~ + +**File**: ``include/ck_tile/core/container/multi_index.hpp`` + +.. code-block:: cpp + + #include + #include + + template + __device__ void example_p_space_calculation() + { + // Get P-coordinates from hardware thread IDs + const index_t thread_id = get_thread_local_1d_id(); + const index_t warp_id = get_warp_local_1d_id(); + const index_t lane_id = get_lane_id(); + + // Convert to multi-dimensional P-coordinates + auto p_coord_2d = make_multi_index(warp_id, lane_id); + + // Using tile distribution (preferred method) + constexpr auto tile_distribution = TileDistribution{}; + const auto p_coord = tile_distribution.calculate_p_coord(); + + // P-coordinates determine: + // 1. Work distribution - which data this thread processes + // 2. Memory coalescing - ensuring optimal access patterns + // 3. Thread cooperation - coordinating shared memory usage + } + +The P-space abstraction enables CK to handle different GPU architectures transparently. Whether running on GPUs with 32-thread warps or 64-thread wavefronts, the P-space coordinates provide a consistent interface while the underlying implementation adapts to the hardware. 
+ +Logical Work Organization +---------------------------------- + +Yield Space (Y-space) represents the logical organization of work within each thread's assigned tile. While P-space identifies which thread is executing, Y-space defines what that thread does with its assigned work. This abstraction enables the expression of complex access patterns in a hardware-independent manner. + +Work Assignment Structure +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Thread's Tile (2x2 elements)" + Y00["Y=[0,0]
Element 0"] + Y01["Y=[0,1]
Element 1"] + Y10["Y=[1,0]
Element 2"] + Y11["Y=[1,1]
Element 3"] + end + + subgraph "Y-space Structure" + YS["Each thread processes
the same Y-space pattern
but at different X locations"] + end + + subgraph "Example: 4 Threads" + T0["Thread 0
P=[0,0]"] + T1["Thread 1
P=[0,1]"] + T2["Thread 2
P=[1,0]"] + T3["Thread 3
P=[1,1]"] + end + + Y00 --> YS + Y01 --> YS + Y10 --> YS + Y11 --> YS + + T0 --> YS + T1 --> YS + T2 --> YS + T3 --> YS + + style Y00 fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style Y01 fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style Y10 fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style Y11 fill:#fff3e0,stroke:#f57c00,stroke-width:2px + + + + + + +.. image:: diagrams/coordinate_systems_3.svg + :alt: Diagram + :align: center + +The power of Y-space lies in its ability to express different iteration patterns without changing the underlying distribution logic. A thread might traverse its Y-space in row-major order for one algorithm, column-major for another, or even use :ref:`space-filling curves ` for optimal cache utilization. This flexibility enables algorithm-specific optimizations while maintaining a consistent framework. + +Hierarchical Y-Space +~~~~~~~~~~~~~~~~~~~~ + +For complex kernels, Y-space can have a hierarchical structure that mirrors the hierarchical nature of GPU architectures: + +.. code-block:: cpp + + // Hierarchical Y-space for complex kernels + template + __device__ void example_hierarchical_y_space() + { + constexpr auto tile_distribution = TileDistribution{}; + + // 4D Y-space: [repeat, warp, thread, vector] + constexpr auto y_hierarchical = make_tuple( + number<4>{}, // Repeat dimension + number<2>{}, // Warp dimension + number<8>{}, // Thread dimension + number<4>{} // Vector dimension + ); + + // Each dimension serves different purpose: + // - Repeat: Algorithm repetition (e.g., attention heads) + // - Warp: Inter-warp cooperation patterns + // - Thread: Per-thread work items + // - Vector: SIMD vectorization + + // Sweep through Y-space with compile-time unrolling + sweep_tile(distributed_tensor, [&](auto y_coord) { + // y_coord is compile-time multi_index + // All iterations unrolled at compile time + auto value = distributed_tensor(y_coord); + // Process value... 
+ }); + } + +Physical Tensor Coordinates +------------------------------------ + +X-space represents the ground truth of data organization: the actual coordinates within the global tensor. This space directly corresponds to how users conceptualize their data: row and column indices for matrices, spatial coordinates for images, or multi-dimensional indices for general tensors. + +Memory Layout Mapping +~~~~~~~~~~~~~~~~~~~~~ + +The relationship between X-space and physical memory involves considerations of data layout, padding, and alignment: + +.. code-block:: cpp + + template + __device__ void example_x_space_operations() + { + constexpr auto tensor_desc = TensorDescriptor{}; + + // X-space properties + constexpr auto x_lengths = tensor_desc.get_lengths(); + constexpr auto x_strides = tensor_desc.get_strides(); + + // Direct X-coordinate specification + constexpr auto x_coord = make_multi_index(number<3>{}, number<4>{}); + + // Convert to linear offset + constexpr auto linear_offset = tensor_desc.calculate_offset(x_coord); + + // X-coordinates from P+Y transformation + const auto x_from_py = tile_dist.calculate_index(p_coord, y_coord); + + // Bounds checking + const bool valid = is_valid_x_coord(x_coord, x_lengths); + } + +The Core Transformation: P + Y → X +---------------------------------- + +The transformation from P and Y coordinates to X coordinates represents the heart of tile distribution. This transformation encodes the entire distribution strategy, determining how logical thread work maps to physical tensor locations. + +Transformation Pipeline +~~~~~~~~~~~~~~~~~~~~~~~ + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Input" + P["P-coordinates
Thread identity
P=[1,0]"] + Y["Y-coordinates
Element in tile
Y=[0,1]"] + end + + subgraph "Transformation" + T["P + Y → X
Base position + Offset"] + end + + subgraph "Output" + X["X-coordinates
Tensor position
X=[2,1]"] + end + + subgraph "Example" + E["Thread P=[1,0] at base (2,0)
Element Y=[0,1] adds offset (0,1)
Result X=[2,1] in tensor"] + end + + P --> T + Y --> T + T --> X + X --> E + + style P fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style Y fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style X fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + + + + +.. image:: diagrams/coordinate_systems_4.svg + :alt: Diagram + :align: center + +Mathematical Foundation +~~~~~~~~~~~~~~~~~~~~~~~ + +The P+Y→X transformation can be expressed mathematically as a composition of functions: + +.. math:: + + X = f(P, Y) = BasePosition(P) + LocalOffset(Y) + +Where: +- BasePosition(P) determines where in the tensor this thread's tile begins +- LocalOffset(Y) specifies the offset within the tile + +This transformation is highly configurable through the distribution encoding, enabling different strategies for different algorithms while maintaining the same mathematical framework. + +Replication and Cooperation +------------------------------------ + +Replication Space (R-space) introduces a mechanism for expressing data sharing and cooperation patterns between threads. Unlike the other coordinate spaces which map to unique data elements, R-space enables multiple processing elements to work on the same data, facilitating communication and reduction operations. + +Replication Patterns +~~~~~~~~~~~~~~~~~~~~ + +.. 
code-block:: cpp + + template + __device__ void example_r_space_operations() + { + constexpr auto tile_distribution = TileDistribution{}; + constexpr auto r_lengths = tile_distribution.get_r_lengths(); + + // Broadcasting with R-space + template + __device__ auto broadcast_across_r_space(DataType value) + { + const auto r_coord = tile_distribution.calculate_r_coord(); + __shared__ DataType shared_value; + + if (r_coord == make_multi_index(0, 0)) { + shared_value = value; // Source thread + } + __syncthreads(); + + return shared_value; // All threads get the value + } + + // Reduction across R-space + template + __device__ auto reduce_across_r_space(DataType local_value) + { + // Use hardware-accelerated reduction + return block_reduce_sum(local_value); + } + } + +R-space enables cooperation patterns that would be difficult to express otherwise. By providing a systematic way to identify which threads share data, it enables automatic generation of communication patterns. + +Memory Linearization +----------------------------- + +D-space represents the final transformation in the coordinate pipeline: converting multi-dimensional coordinates to linear memory addresses. This transformation incorporates all the low-level details of memory layout, including stride patterns, padding, and alignment requirements. + +Linearization Strategies +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "X-coordinates" + X["X = [2, 3]
2D Position"] + end + + subgraph "Layout Options" + RM["Row-Major
D = 2×width + 3"] + CM["Column-Major
D = 3×height + 2"] + BL["Blocked
Complex pattern"] + end + + subgraph "D-coordinate" + D["D = 11
Linear Address"] + end + + X --> RM + X --> CM + X --> BL + RM --> D + + style X fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style D fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + + + + + + +.. image:: diagrams/coordinate_systems_5.svg + :alt: Diagram + :align: center + +The linearization process must consider multiple factors: + +.. code-block:: cpp + + template + __device__ void example_d_space_linearization() + { + // Standard linearization + template + __device__ constexpr auto calculate_linear_offset(const XCoord& x_coord) + { + index_t offset = 0; + static_for<0, ndim, 1>{}([&](auto dim) { + offset += x_coord.at(dim) * strides.at(dim); + }); + return offset; + } + + // Specialized patterns for optimization + // Row-major: offset = x0 * N + x1 + // Column-major: offset = x1 * M + x0 + // Blocked: Complex pattern for cache efficiency + } + +Complete Pipeline Example +------------------------- + +The following is a complete example showing how all coordinate spaces work together: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Step 1: Thread Identification" + TID["Thread ID = 5"] + P["P-coordinates
P = [0, 5]
(warp 0, lane 5)"] + end + + subgraph "Step 2: Work Assignment" + Y["Y-coordinates
Y = [1, 0]
(element in tile)"] + end + + subgraph "Step 3: P+Y Transformation" + TRANS["P + Y → X
Thread position + Element offset"] + X["X-coordinates
X = [1, 5]
(tensor position)"] + end + + subgraph "Step 4: Linearization" + LIN["X → D
Row-major: D = x₀ × width + x₁"] + D["D-coordinate
D = 13
(memory address)"] + end + + subgraph "Step 5: Memory Access" + MEM["Hardware accesses
memory[13]"] + end + + TID --> P + P --> TRANS + Y --> TRANS + TRANS --> X + X --> LIN + LIN --> D + D --> MEM + + style P fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style Y fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style X fill:#e8f5e9,stroke:#388e3c,stroke-width:3px + style D fill:#f3e5f5,stroke:#7b1fa2,stroke-width:3px + style MEM fill:#ffebee,stroke:#c62828,stroke-width:3px + + + + +.. image:: diagrams/coordinate_systems_6.svg + :alt: Diagram + :align: center + +Real-World Example: Matrix Multiplication +----------------------------------------- + +:ref:`matrix multiplication ` demonstrates how coordinate systems work in practice/ + +.. code-block:: cpp + + template + __global__ void gemm_kernel_with_coordinates( + const AType* a_ptr, const BType* b_ptr, CType* c_ptr, + index_t M, index_t N, index_t K) + { + // Define distribution encoding + using Encoding = tile_distribution_encoding< + sequence<>, // R: no replication + tuple, // H for M dimension + sequence<4, 2, 8, 4>>, // H for N dimension + tuple, sequence<1, 2>>, // P mappings + tuple, sequence<2, 2>>, // P minor + sequence<1, 1, 2, 2>, // Y major + sequence<0, 3, 0, 3> // Y minor + >; + + constexpr auto distribution = make_static_tile_distribution(Encoding{}); + + // Step 1: Get P-coordinates (thread identity) + const auto p_coord = distribution.calculate_p_coord(); + + // Step 2: Iterate through Y-space (work assignment) + sweep_tile(c_tile, [&](auto y_coord) { + // Step 3: P+Y→X transformation + const auto x_coord = distribution.calculate_index(p_coord, y_coord); + + // Step 4: X→D transformation (handled by tensor view) + // Step 5: Actual computation at these coordinates + c_tile(y_coord) = compute_element(x_coord); + }); + } + +Performance Implications +------------------------ + +The coordinate system framework enables several critical optimizations: + +**Memory Coalescing**: By carefully structuring the P+Y→X transformation, consecutive threads access consecutive memory locations, 
achieving optimal memory bandwidth utilization. + +**Cache Efficiency**: The Y-space traversal order can be designed to maximize cache reuse, keeping frequently accessed data in fast memory. + +**Register Optimization**: The Y→D transformation enables optimal register allocation, minimizing register pressure while maximizing reuse. + +**Vectorization**: The coordinate transformations naturally align with vector operations, enabling efficient use of SIMD instructions. + +Summary +------- + +The coordinate system framework represents the mathematical foundation that enables CK's high performance and productivity benefits. Through the systematic transformation from thread identity (P-space) through logical work organization (Y-space) to physical tensor coordinates (X-space) and finally to linear memory addresses (D-space), this framework solves the fundamental challenges of GPU programming. + +Key insights from the coordinate system framework: + +**Separation of Concerns**: Each coordinate space captures a different aspect of the computation, enabling independent optimization of each aspect while maintaining a coherent whole. + +**Mathematical Rigor**: The transformations between coordinate spaces are well-defined mathematical functions, enabling formal analysis and verification of distribution strategies. + +**Hardware Abstraction**: The framework abstracts hardware details while enabling hardware-specific optimizations, achieving both portability and performance. + +**Automatic Optimization**: By encoding distribution strategies as coordinate transformations, the framework enables automatic generation of optimal access patterns that would be impractical to implement manually. + +**Composability**: Different distribution strategies can be expressed by composing different transformations, enabling rapid experimentation and optimization. 
+ +These coordinate systems provide the conceptual framework for reasoning about GPU computation and the practical tools for achieving optimal performance. As GPU architectures continue to evolve, this mathematical foundation ensures that CK programs can adapt and continue to achieve high performance. + +Next Steps +---------- + +With a solid understanding of the coordinate system framework, the next sections explore how these concepts are applied in practice. Return to :ref:`ck_tile_index` to see the structure of the complete CK Tile documentation. diff --git a/docs/conceptual/ck_tile/descriptors.rst b/docs/conceptual/ck_tile/descriptors.rst new file mode 100644 index 0000000000..3a52097d06 --- /dev/null +++ b/docs/conceptual/ck_tile/descriptors.rst @@ -0,0 +1,383 @@ +.. _ck_tile_descriptors: + +Tensor Descriptors - Complete Tensor Specifications +=================================================== + +Overview +-------- + +A TensorDescriptor is the complete blueprint for a tensor. It combines a shape, stride information, and a series of :ref:`transformations ` into a single object that defines exactly how a tensor's data is laid out in memory. This specification enables CK Tile to create complex tensor views without any data movement. + +In CK Tile, TensorDescriptors serve as the foundation for all tensor operations, providing: + +- **Memory Layout Specification**: How data is arranged in physical memory +- **Logical View Definition**: How the tensor appears to the programmer +- **Transformation Pipeline**: A series of :ref:`coordinate transformations ` +- **Zero-Copy Views**: Different logical representations of the same data, building on :ref:`BufferViews ` and :ref:`TensorViews ` + +Creating Basic Tensor Layouts +----------------------------- + +CK Tile provides several ways to create tensor descriptors for common memory layouts. + +Custom Strides +~~~~~~~~~~~~~~ + +The most fundamental way to define a tensor is with custom strides. 
This provides full control over how many elements to "jump" in memory to move to the next item along each dimension. This is particularly useful for creating padded layouts required by GPU algorithms. + +.. code-block:: cpp + + using namespace ck_tile; + + // Create a 3x4 tensor, but make each row take up 8 elements in memory + // (4 for data, 4 for padding) + constexpr auto M = 3; + constexpr auto N = 4; + constexpr auto RowStride = 8; // Padded stride + + auto descriptor = make_naive_tensor_descriptor( + make_tuple(M, N), // Shape: [3, 4] + make_tuple(RowStride, 1) // Strides: [8, 1] + ); + + // The total memory needed is 3 rows * 8 elements/row = 24 + constexpr auto element_space_size = M * RowStride; + + // Calculate offset of the element at [row=1, col=2] + multi_index<2> coord{1, 2}; + auto offset = descriptor.calculate_offset(coord); + // offset = 1*8 + 2*1 = 10 + +Packed Row-Major Layout +~~~~~~~~~~~~~~~~~~~~~~~~~ + +For most cases, a tightly packed, row-major layout is sufficient. The strides are calculated automatically, leaving no unused space between elements. + +.. code-block:: cpp + + using namespace ck_tile; + + // Create a packed 3x4 tensor + auto descriptor_packed = make_naive_tensor_descriptor_packed( + make_tuple(3, 4) + ); + + // Total memory is 3 * 4 = 12 elements + // Strides are automatically [4, 1] for row-major layout + + // Calculate offset of the element at [row=1, col=2] + multi_index<2> coord{1, 2}; + auto offset = descriptor_packed.calculate_offset(coord); + // offset = 1*4 + 2*1 = 6 + +Aligned Layout +~~~~~~~~~~~~~~ + +For GPU performance, memory layouts often need to be aligned. This function creates a row-major layout but ensures that each row's starting address is a multiple of a given alignment value, adding padding if necessary. + +.. 
code-block:: cpp + + using namespace ck_tile; + + // Create a 4x5 tensor with 8-element alignment + constexpr auto align = 8; // Align each row to 8-element boundary + + auto descriptor_aligned = make_naive_tensor_descriptor_aligned( + make_tuple(4, 5), + align + ); + + // Without alignment, size would be 4*5=20 + // With alignment, the row stride becomes 8 (smallest multiple of 8 >= 5) + // Total size = 4 rows * 8 elements/row = 32 + +The Pipeline Concept +-------------------- + +Every TensorDescriptor in CK Tile can be thought of as a **transformation pipeline**. The functions above create the *first stage* of this pipeline, defining the initial :ref:`transformation ` that takes a simple, one-dimensional block of memory and presents it as a logical, multi-dimensional tensor view. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Pipeline Stages" + S1["Stage 1
Base Layout
[M, N]"] + S2["Stage 2
Transform
Unmerge"] + S3["Stage 3
New View
[M1, M2, N]"] + S4["Stage N
Final View
[...]"] + end + + subgraph "Same Data" + D["Physical Memory
No data movement"] + end + + S1 --> S2 + S2 --> S3 + S3 --> S4 + + S1 -.-> D + S2 -.-> D + S3 -.-> D + S4 -.-> D + + style D fill:#ffebee,stroke:#d32f2f,stroke-width:2px + style S1 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style S3 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + + +.. image:: diagrams/descriptors_1.svg + :alt: Diagram + :align: center + +.. image:: diagrams/descriptors_1.svg + :alt: Diagram + :align: center + +The Initial Pipeline Stage +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A simple packed descriptor sets up a pipeline with a single transform: + +- **Input**: The raw, one-dimensional memory buffer (hidden dimension ID 0) +- **Output**: The logical dimensions that you interact with (hidden dimension IDs 1, 2, ...) + +This initial stage converts linear memory addresses into multi-dimensional coordinates. See :ref:`ck_tile_adaptors` for how transforms chain together. + +Advanced Layouts: Step-by-Step Transformation +--------------------------------------------- + +The ``transform_tensor_descriptor`` function adds new stages to an existing descriptor's pipeline using :ref:`transforms `. + +Transform a [2, 6] Tensor into a [2, 2, 3] View +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This example reinterprets a 2D tensor with shape [2, 6] as a 3D tensor with shape [2, 2, 3], without changing the underlying 12-element memory buffer. + +**Step 1: Define the Base Descriptor** + +.. 
code-block:: cpp + + using namespace ck_tile; + + // Create the [2, 6] base descriptor + auto base_descriptor = make_naive_tensor_descriptor_packed( + make_tuple(2, 6) + ); + + // This creates an initial pipeline stage that: + // - Takes the raw buffer (hidden ID 0) as input + // - Produces two outputs (hidden IDs 1 and 2) + // - These outputs become logical dimensions 0 and 1 + +**Step 2: Define the New Transformation Stage** + +To get from [2, 6] to [2, 2, 3], we need: + +- **For logical dimension 0 (length 2)**: Preserve it with PassThroughTransform +- **For logical dimension 1 (length 6)**: Split it with UnmergeTransform([2, 3]) + +**Step 3: Apply Transformation** + +.. code-block:: cpp + + // Create the transformed descriptor + auto transformed_descriptor = transform_tensor_descriptor( + base_descriptor, + make_tuple( + make_pass_through_transform(2), // For dim 0 + make_unmerge_transform(make_tuple(2, 3)) // For dim 1 + ), + make_tuple(sequence<0>{}, sequence<1>{}), // Input mapping + make_tuple(sequence<0>{}, sequence<1, 2>{}) // Output mapping + ); + + // Result: A [2, 2, 3] view of the same data + +Analysis of the Final Pipeline +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Transform Pipeline" + T0["Transform 0
Base Unmerge
Input: [0]
Output: [1,2]"] + T1["Transform 1
PassThrough
Input: [1]
Output: [3]"] + T2["Transform 2
Unmerge
Input: [2]
Output: [4,5]"] + end + + subgraph "Hidden Dimensions" + H0["Hidden ID 0
Raw Buffer"] + H1["Hidden ID 1
Dim 0 (size 2)"] + H2["Hidden ID 2
Dim 1 (size 6)"] + H3["Hidden ID 3
Final Dim 0"] + H4["Hidden ID 4
Final Dim 1"] + H5["Hidden ID 5
Final Dim 2"] + end + + H0 --> T0 + T0 --> H1 + T0 --> H2 + H1 --> T1 + H2 --> T2 + T1 --> H3 + T2 --> H4 + T2 --> H5 + + style H0 fill:#ffebee,stroke:#d32f2f,stroke-width:2px + style H3 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style H4 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style H5 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + + + +.. image:: diagrams/descriptors_2.svg + :alt: Diagram + :align: center + +.. image:: diagrams/descriptors_2.svg + :alt: Diagram + :align: center + +The pipeline now has three stages: + +1. **Base UnmergeTransform**: Converts raw buffer to [2, 6] layout +2. **PassThroughTransform**: Preserves the first dimension +3. **UnmergeTransform**: Splits the second dimension into [2, 3] + +5D to 3D Block Transformation +----------------------------------------------------- + +These concepts are critical in :ref:`GPU programming `. This example transforms a 5D tensor representing a GPU thread block's workload into a simpler 3D view using MergeTransform. See :ref:`ck_tile_thread_mapping` for thread distribution details. + +.. 
code-block:: cpp + + using namespace ck_tile; + + // Define parameters (typical for a GPU block) + constexpr auto Block_M = 256; + constexpr auto NumWarps = 8; + constexpr auto WarpSize = 64; + constexpr auto KVector = 4; + constexpr auto wavesPerK = 2; + constexpr auto wavesPerM = NumWarps / wavesPerK; + constexpr auto NumIssues = Block_M / wavesPerM; + + // Create the base 5D descriptor + auto base_descriptor = make_naive_tensor_descriptor_packed( + make_tuple(NumIssues, wavesPerM, wavesPerK, WarpSize, KVector) + ); + + // Transform to 3D by merging dimensions + auto transformed_descriptor = transform_tensor_descriptor( + base_descriptor, + make_tuple( + make_pass_through_transform(NumIssues), + make_merge_transform(make_tuple(wavesPerM, wavesPerK)), + make_merge_transform(make_tuple(WarpSize, KVector)) + ), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}) + ); + + // Result: [NumIssues, wavesPerM*wavesPerK, WarpSize*KVector] + // This simplifies thread block management while preserving data layout + +Common Descriptor Patterns +-------------------------- + +Matrix Transposition +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + // Create a transposed view of a matrix + auto transposed = transform_tensor_descriptor( + original_matrix, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(M) + ), + make_tuple(sequence<1>{}, sequence<0>{}), // Swap dimensions + make_tuple(sequence<0>{}, sequence<1>{}) + ); + +Padding for Convolution +~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
code-block:: cpp + + // Add padding to spatial dimensions + auto padded = transform_tensor_descriptor( + input_tensor, + make_tuple( + make_pass_through_transform(N), // Batch + make_pass_through_transform(C), // Channel + make_pad_transform(H, pad_h, pad_h), // Height + make_pad_transform(W, pad_w, pad_w) // Width + ), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}) + ); + +For a complete convolution example, see :ref:`ck_tile_convolution_example`. + +Tensor Slicing +~~~~~~~~~~~~~~ + +.. code-block:: cpp + + // Extract a sub-tensor + auto slice = transform_tensor_descriptor( + full_tensor, + make_tuple( + make_slice_transform(M, start_m, end_m), + make_slice_transform(N, start_n, end_n) + ), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{}) + ); + +Key Concepts Summary +-------------------- + +TensorDescriptors provide a key abstraction for tensor manipulation: + +- **Pipeline Architecture**: Each descriptor is a transformation pipeline +- **Zero-Copy Views**: All transformations are logical, no data movement +- **Composability**: Complex layouts built from simple transforms +- **GPU Optimization**: Designed for efficient GPU memory access patterns + +Important principles: + +1. **Always Handle All Dimensions**: When transforming, provide a transform for each input dimension +2. **Hidden Dimension IDs**: Track the flow of data through the pipeline +3. **Compile-Time Resolution**: All transformations resolved at compile time +4. **Type Safety**: Template metaprogramming ensures correctness + +Performance Considerations +-------------------------- + +When designing tensor descriptors for GPU kernels: + +1. **Memory Coalescing**: Ensure contiguous threads access contiguous memory +2. **Bank Conflicts**: Avoid patterns that cause :ref:`shared memory conflicts ` +3. **Alignment**: Use aligned layouts for better memory throughput +4. 
**Padding**: Strategic padding can improve access patterns. See :ref:`ck_tile_lds_index_swapping` for advanced techniques. + +Next Steps +---------- + +- :ref:`ck_tile_tile_window` - Using descriptors for efficient data loading +- :ref:`ck_tile_tile_distribution` - How descriptors enable automatic work distribution +- :ref:`ck_tile_convolution_example` - Real-world application of complex descriptors +- :ref:`ck_tile_static_distributed_tensor` - Managing distributed tensors with descriptors +- :ref:`ck_tile_gemm_optimization` - GEMM kernels using descriptor transformations diff --git a/docs/conceptual/ck_tile/diagrams/adaptors_1.svg b/docs/conceptual/ck_tile/diagrams/adaptors_1.svg new file mode 100644 index 0000000000..e7ab20b093 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/adaptors_1.svg @@ -0,0 +1 @@ +

Adaptor Composition

Chained Transforms

Input
2D

Transform A
(e.g., Merge)

Intermediate
1D

Transform B
(e.g., Pad)

Output
1D Padded

Single Transform

Input Coords
[0,1,2]

Transform
(e.g., Transpose)

Output Coords
[2,0,1]

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/adaptors_2.svg b/docs/conceptual/ck_tile/diagrams/adaptors_2.svg new file mode 100644 index 0000000000..417ff1b19c --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/adaptors_2.svg @@ -0,0 +1 @@ +

Adaptor Chaining Flow

Chained Result

Adaptor 2

Adaptor 1

Input 2D
Bottom[0,1]

Bottom Dims
[0,1]

Transform:
Merge[2,3]

Top Dims
[0]

Bottom Dims
[0]

Transform:
Unmerge[2,3]

Top Dims
[0,1]

Output 2D
Top[0,1]

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/buffer_views_1.svg b/docs/conceptual/ck_tile/diagrams/buffer_views_1.svg new file mode 100644 index 0000000000..fb696c9e42 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/buffer_views_1.svg @@ -0,0 +1 @@ +

Usage Pattern

1. Load tile from Global → LDS
2. Load working set LDS → VGPR
3. Compute in VGPR
4. Store results VGPR → LDS
5. Reduce in LDS
6. Write final LDS → Global

Compute Flow

Global Memory
Input Data

LDS
Tile Cache

VGPR
Working Set

Compute
Operations

LDS
Reduction

Global Memory
Output Data

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/buffer_views_2.svg b/docs/conceptual/ck_tile/diagrams/buffer_views_2.svg new file mode 100644 index 0000000000..7a58311b33 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/buffer_views_2.svg @@ -0,0 +1 @@ +

Performance Impact

Vectorized Access (1 instruction)

Scalar Access (4 instructions)

Load float[0]

Register 1

Load float[1]

Register 2

Load float[2]

Register 3

Load float[3]

Register 4

Load float4[0]

Vector Register
(4 floats)

4x fewer instructions
Better memory bandwidth
Reduced latency

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/buffer_views_3.svg b/docs/conceptual/ck_tile/diagrams/buffer_views_3.svg new file mode 100644 index 0000000000..8e20da9fa0 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/buffer_views_3.svg @@ -0,0 +1 @@ +

Output

Processing

Input Parameters

Yes

No

Yes

No

Offset
(e.g., 5)

Valid Flag
(optional)

Bounds Check
offset < buffer_size?

Flag Check
valid_flag == True?

Access Memory
buffer[offset]

Valid Result
Return value

Invalid Result
Return 0 or default

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/buffer_views_4.svg b/docs/conceptual/ck_tile/diagrams/buffer_views_4.svg new file mode 100644 index 0000000000..f0b04d283b --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/buffer_views_4.svg @@ -0,0 +1 @@ +

Atomic Operation (Thread-Safe)

Thread 1: atomic_add(5)

Hardware ensures
serialization

Thread 2: atomic_add(3)

Final value: 18 ✓
(Both updates applied)

Non-Atomic Operation (Race Condition)

Thread 1: Read value (10)

Thread 1: Add 5 (15)

Thread 2: Read value (10)

Thread 2: Add 3 (13)

Thread 1: Write 15

Thread 2: Write 13

Final value: 13 ❌
(Lost update from Thread 1)

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/convolution_example.svg b/docs/conceptual/ck_tile/diagrams/convolution_example.svg new file mode 100644 index 0000000000..4a86641997 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/convolution_example.svg @@ -0,0 +1 @@ +

Im2col Optimization

Convolution Process

Input Image
6×6

Kernel
3×3

Sliding Window
Extract 3×3 patches

Dot Product
Element-wise multiply & sum

Output
4×4

Windows Matrix
16×9
(all patches)

Kernel Flattened
9×1

Matrix Multiply
W @ K

Output Flattened
16×1

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/coordinate_movement.svg b/docs/conceptual/ck_tile/diagrams/coordinate_movement.svg new file mode 100644 index 0000000000..18f190d646 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/coordinate_movement.svg @@ -0,0 +1 @@ +

Movement Example

Start: [1,1]
Offset: 5

Move [0,1]
→ [1,2]
Offset: 6

Move [1,0]
→ [2,2]
Offset: 10

Move [1,1]
→ [3,3]
Offset: 15

Coordinate Movement System

TensorCoordinate
Position + Descriptor Context

move_coordinate()
Efficient Navigation

TensorAdaptorCoordinate
Position + Transform Context

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/coordinate_systems_1.svg b/docs/conceptual/ck_tile/diagrams/coordinate_systems_1.svg new file mode 100644 index 0000000000..8890aa2362 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/coordinate_systems_1.svg @@ -0,0 +1 @@ +

Transformations

Coordinate Spaces Overview

P-space
Thread Identification
Which thread am I?

Y-space
Logical Tile
Which element in my tile?

X-space
Physical Tensor
Where in the tensor?

R-space
Replication
Data sharing pattern

D-space
Linear Storage
Memory address

P + Y → X
Thread + Element → Position

X → D
Position → Address

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/coordinate_systems_2.svg b/docs/conceptual/ck_tile/diagrams/coordinate_systems_2.svg new file mode 100644 index 0000000000..765318910a --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/coordinate_systems_2.svg @@ -0,0 +1 @@ +

P-space Mapping

GPU Thread Hierarchy

Block

Warp 1

Warp 0

Thread 0
P=[0,0]

Thread 1
P=[0,1]

Thread 2
P=[0,2]

...

Thread 31
P=[0,31]

Thread 32
P=[1,0]

Thread 33
P=[1,1]

...

Thread 63
P=[1,31]

Warp 2...

Warp 7

P-coordinates = [warp_id, lane_id]
or
P-coordinates = [block_x, block_y, thread_x, thread_y]

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/coordinate_systems_3.svg b/docs/conceptual/ck_tile/diagrams/coordinate_systems_3.svg new file mode 100644 index 0000000000..47846dfe4b --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/coordinate_systems_3.svg @@ -0,0 +1 @@ +

Example: 4 Threads

Y-space Structure

Thread's Tile (2x2 elements)

Y=[0,0]
Element 0

Y=[0,1]
Element 1

Y=[1,0]
Element 2

Y=[1,1]
Element 3

Each thread processes
the same Y-space pattern
but at different X locations

Thread 0
P=[0,0]

Thread 1
P=[0,1]

Thread 2
P=[1,0]

Thread 3
P=[1,1]

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/coordinate_systems_4.svg b/docs/conceptual/ck_tile/diagrams/coordinate_systems_4.svg new file mode 100644 index 0000000000..3a9f04c73d --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/coordinate_systems_4.svg @@ -0,0 +1 @@ +

Example

Output

Transformation

Input

P-coordinates
Thread identity
P=[1,0]

Y-coordinates
Element in tile
Y=[0,1]

P + Y → X
Base position + Offset

X-coordinates
Tensor position
X=[2,1]

Thread P=[1,0] at base (2,0)
Element Y=[0,1] adds offset (0,1)
Result X=[2,1] in tensor

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/coordinate_systems_5.svg b/docs/conceptual/ck_tile/diagrams/coordinate_systems_5.svg new file mode 100644 index 0000000000..f91d8b39ef --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/coordinate_systems_5.svg @@ -0,0 +1 @@ +

D-coordinate

Layout Options

X-coordinates

X = [2, 3]
2D Position

Row-Major
D = 2×width + 3

Column-Major
D = 3×height + 2

Blocked
Complex pattern

D = 11
Linear Address

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/coordinate_systems_6.svg b/docs/conceptual/ck_tile/diagrams/coordinate_systems_6.svg new file mode 100644 index 0000000000..0e0275457a --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/coordinate_systems_6.svg @@ -0,0 +1 @@ +

Step 5: Memory Access

Step 4: Linearization

Step 3: P+Y Transformation

Step 2: Work Assignment

Step 1: Thread Identification

Thread ID = 5

P-coordinates
P = [0, 5]
(warp 0, lane 5)

Y-coordinates
Y = [1, 0]
(element in tile)

P + Y → X
Thread position + Element offset

X-coordinates
X = [1, 5]
(tensor position)

X → D
Row-major: D = x₀ × width + x₁

D-coordinate
D = 13
(memory address)

Hardware accesses
memory[13]

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/descriptors_1.svg b/docs/conceptual/ck_tile/diagrams/descriptors_1.svg new file mode 100644 index 0000000000..a46b34e45d --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/descriptors_1.svg @@ -0,0 +1 @@ +

Same Data

Pipeline Stages

Stage 1
Base Layout
[M, N]

Stage 2
Transform
Unmerge

Stage 3
New View
[M1, M2, N]

Stage N
Final View
[...]

Physical Memory
No data movement

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/descriptors_2.svg b/docs/conceptual/ck_tile/diagrams/descriptors_2.svg new file mode 100644 index 0000000000..f9ebb053c0 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/descriptors_2.svg @@ -0,0 +1 @@ +

Hidden Dimensions

Transform Pipeline

Transform 0
Base Unmerge
Input: [0]
Output: [1,2]

Transform 1
PassThrough
Input: [1]
Output: [3]

Transform 2
Unmerge
Input: [2]
Output: [4,5]

Hidden ID 0
Raw Buffer

Hidden ID 1
Dim 0 (size 2)

Hidden ID 2
Dim 1 (size 6)

Hidden ID 3
Final Dim 0

Hidden ID 4
Final Dim 1

Hidden ID 5
Final Dim 2

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/encoding_internals_1.svg b/docs/conceptual/ck_tile/diagrams/encoding_internals_1.svg new file mode 100644 index 0000000000..41647a8c0e --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/encoding_internals_1.svg @@ -0,0 +1 @@ +

Transformation Chain

Generated Components

Encoding Components

R-space Lengths
Replication dimensions

H-space Lengths
Hierarchical decomposition
[[2,2],[2,2]]

P→RH Mappings
Thread to hierarchy
Major/Minor

Y→RH Mappings
Element to hierarchy
Major/Minor

ps_ys_to_xs_adaptor
Coordinate transformer

ys_to_d_descriptor
Memory linearizer

Encoding
Original specification

Replicate
Transform

Unmerge
Transform

Merge
Transform

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/encoding_internals_2.svg b/docs/conceptual/ck_tile/diagrams/encoding_internals_2.svg new file mode 100644 index 0000000000..4032376a6a --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/encoding_internals_2.svg @@ -0,0 +1 @@ +

Output

Transformation Pipeline

Input Coordinates

P-coordinates
[warp_id, lane_id]

Y-coordinates
[y0, y1, y2, y3]

Combine P+Y

Replicate
Transform
(if R-dims exist)

Unmerge
Transform
(break into H-dims)

Merge
Transform
(combine to X-dims)

X-coordinates
[x0, x1]
Tensor position

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/introduction_motivation_1.svg b/docs/conceptual/ck_tile/diagrams/introduction_motivation_1.svg new file mode 100644 index 0000000000..55253de744 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/introduction_motivation_1.svg @@ -0,0 +1 @@ +

Tile Distribution Pattern (Efficient)

Memory_TD

Threads_TD

Mem[0]

Thread 0

Mem[1]

Thread 1

Mem[2]

Mem[3]

Thread 2

Mem[4]

Mem[5]

Thread 3

Mem[6]

Mem[7]

Random Access Pattern (Inefficient)

Memory

Threads

Mem[0]

Thread 0

Mem[23]

Thread 1

Mem[7]

Thread 2

Mem[47]

Thread 3

Mem[15]

Mem[31]

Mem[39]

Mem[55]

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/introduction_motivation_2.svg b/docs/conceptual/ck_tile/diagrams/introduction_motivation_2.svg new file mode 100644 index 0000000000..524b6b2d40 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/introduction_motivation_2.svg @@ -0,0 +1 @@ +

Transformations

Coordinate Spaces

P-space
Thread Position
(thread_x, thread_y,
warp_id, block_id)

Y-space
Local Data
(y0, y1, y2, y3)

X-space
Global Position
(x0, x1)

D-space
Memory Address
(linearized)

P + Y → X
Thread data mapping

X → D
Memory linearization

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/lds_index_swapping_1.svg b/docs/conceptual/ck_tile/diagrams/lds_index_swapping_1.svg new file mode 100644 index 0000000000..26bf6ec7f5 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/lds_index_swapping_1.svg @@ -0,0 +1 @@ +

Update K0 with XOR transformation

XOR Transform

3D LDS coordinate [K0, M, K1]

KPerBlock/KPack * MLdsLayer
K0

MPerBlock/MLdsLayer
M

KPack
K1

make_xor_transform

KPerBlock/KPack * MLdsLayer
K0'

MPerBlock/MLdsLayer
M

KPack
K1

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/lds_index_swapping_2.svg b/docs/conceptual/ck_tile/diagrams/lds_index_swapping_2.svg new file mode 100644 index 0000000000..0b02bce106 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/lds_index_swapping_2.svg @@ -0,0 +1 @@ +

4D intermediate transformation space

Unmerge into 2 components

3D LDS coordinate [K0', M, K1]

KPerBlock/KPack * MLdsLayer
K0'

MPerBlock/MLdsLayer
M

KPack
K1

make_unmerge_transform

MLdsLayer
L

MPerBlock/MLdsLayer
M

KPerBlock/KPack
K0''

KPack
K1

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/lds_index_swapping_3.svg b/docs/conceptual/ck_tile/diagrams/lds_index_swapping_3.svg new file mode 100644 index 0000000000..378c0d35d0 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/lds_index_swapping_3.svg @@ -0,0 +1 @@ +

Transformed 2D coordinates [M', K']

Merge into 1 component

Merge into 1 component

4D LDS Coordinates [L, M, K0'', K1]

MLdsLayer
L

MPerBlock/MLdsLayer
M

KPerBlock/KPack
K0''

KPack
K1

make_merge_transform

make_merge_transform

MPerBlock
M'

KPerBlock
K'

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/load_store_traits_1.svg b/docs/conceptual/ck_tile/diagrams/load_store_traits_1.svg new file mode 100644 index 0000000000..51be25c0b2 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/load_store_traits_1.svg @@ -0,0 +1 @@ +

Yes

No

Analyze Distribution

Check Each Dimension

Calculate Stride

Stride == 1?

Candidate for Vectorization

Skip Dimension

Check Alignment

Check Vector Size

Score Dimension

Select Best Dimension

Configure Vector Access

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/load_store_traits_2.svg b/docs/conceptual/ck_tile/diagrams/load_store_traits_2.svg new file mode 100644 index 0000000000..48b6bff271 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/load_store_traits_2.svg @@ -0,0 +1 @@ +

Snake Pattern

0→1→2→3

7←6←5←4

Cache hit!

8→9→10→11

Linear Traversal

0→1→2→3

4→5→6→7

Cache miss

8→9→10→11

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/space_filling_curve.svg b/docs/conceptual/ck_tile/diagrams/space_filling_curve.svg new file mode 100644 index 0000000000..11b0ceda5b --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/space_filling_curve.svg @@ -0,0 +1 @@ +

Snake Pattern

Row 0: →

Row 1: ←

Row 2: →

Continue

Linear Pattern

Row 0: →

Jump back

Row 1: →

Row 2: →

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/static_distributed_tensor.svg b/docs/conceptual/ck_tile/diagrams/static_distributed_tensor.svg new file mode 100644 index 0000000000..6ce7e3c0c8 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/static_distributed_tensor.svg @@ -0,0 +1 @@ +

Global Tensor 64x64

Thread Block 16x16

Thread 0,0
Elements 0:3,0:3

Thread 0,1
Elements 0:3,4:7

Thread 1,0
Elements 4:7,0:3

...

Local Array
16 elements

Local Array
16 elements

Local Array
16 elements

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/sweep_tile_1.svg b/docs/conceptual/ck_tile/diagrams/sweep_tile_1.svg new file mode 100644 index 0000000000..4f145c81af --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/sweep_tile_1.svg @@ -0,0 +1 @@ +

Computation

Y-Sweep

X-Tile (Reused)

X data loaded once
Stays in registers

Y position 0

Y position 1

Y position 2

Y position N

Process(X, Y)

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/sweep_tile_2.svg b/docs/conceptual/ck_tile/diagrams/sweep_tile_2.svg new file mode 100644 index 0000000000..1ce4e41241 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/sweep_tile_2.svg @@ -0,0 +1 @@ +

Sweep Approach

Load X[0]

Process with
Y[0], Y[1], Y[2]

Load Y[0,1,2]

X loaded once!

Traditional Approach

Load X[0]

Process

Load Y[0]

Load X[0]

Process

Load Y[1]

Load X[0]

Process

Load Y[2]

X loaded 3 times!

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/sweep_tile_3.svg b/docs/conceptual/ck_tile/diagrams/sweep_tile_3.svg new file mode 100644 index 0000000000..10f419cede --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/sweep_tile_3.svg @@ -0,0 +1 @@ +

Use Cases

Sweep Performance Benefits

Zero runtime overhead
Compile-time unrolling

Perfect memory coalescing
Sequential access patterns

Automatic vectorization
Compiler optimizations

Register reuse
X data stays in VGPR

Matrix Multiplication
Reuse A columns

Convolution
Reuse filter weights

Reduction
Accumulate over Y

Broadcast
Apply X to all Y

High Performance

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/sweep_tile_4.svg b/docs/conceptual/ck_tile/diagrams/sweep_tile_4.svg new file mode 100644 index 0000000000..50530522cd --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/sweep_tile_4.svg @@ -0,0 +1 @@ +

Complete Workflow

TileDistribution
Define data layout

TileWindow
Create view

DistributedTensor
Load X data

SweepTile
Iterate Y positions

Results
Store outputs

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tensor_coordinates_1.svg b/docs/conceptual/ck_tile/diagrams/tensor_coordinates_1.svg new file mode 100644 index 0000000000..ee4206f4a2 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tensor_coordinates_1.svg @@ -0,0 +1 @@ +

Usage Context

MultiIndex Structure

MultiIndex
Container for N integers

Dimension 0

Dimension 1

Dimension 2

Dimension N-1

Transforms

Adaptors

Tensors

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tensor_coordinates_2.svg b/docs/conceptual/ck_tile/diagrams/tensor_coordinates_2.svg new file mode 100644 index 0000000000..efada63f93 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tensor_coordinates_2.svg @@ -0,0 +1 @@ +

Example: 3D Tensor Access

3D Tensor
shape=[4,5,6]

MultiIndex(3, [1,2,3])

Element at
position [1,2,3]

Coordinate Flow

User Input
[1, 2, 3]

MultiIndex
Storage

Transform
Processing

MultiIndex
Output

Tensor Access
element(coord)

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tensor_views_1.svg b/docs/conceptual/ck_tile/diagrams/tensor_views_1.svg new file mode 100644 index 0000000000..41338c8902 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tensor_views_1.svg @@ -0,0 +1 @@ +

Logical View

Tensor Layer

Access Layer

Memory Foundation

Flat Memory Array
0 1 2 3 4 5 6 7 8 9 10 11

BufferView
Linear Memory Access

TensorDescriptor
Shape & Stride Info

TensorView
Multi-dimensional Access

2D Matrix View
[3×4]
[[0,1,2,3]
[4,5,6,7]
[8,9,10,11]]

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tensor_views_2.svg b/docs/conceptual/ck_tile/diagrams/tensor_views_2.svg new file mode 100644 index 0000000000..f57636d293 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tensor_views_2.svg @@ -0,0 +1 @@ +

Result

TensorView Processing

User Input

Valid

Coordinate
(1, 2)

Shape Check
row < 3?
col < 4?

Apply Strides
offset = 1×4 + 2×1

BufferView Access
buffer[6]

Value: 6

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tensor_views_3.svg b/docs/conceptual/ck_tile/diagrams/tensor_views_3.svg new file mode 100644 index 0000000000..df13db0c0d --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tensor_views_3.svg @@ -0,0 +1 @@ +

Custom Stride (Transposed View)

Memory: [0,1,2,3,4,5,6,7,8,9,10,11]
Shape: (4,3)
Strides: (1,4)

[[0, 4, 8]
[1, 5, 9]
[2, 6, 10]
[3, 7, 11]]

Column-Major Layout (Fortran-style)

Memory: [0,4,8,1,5,9,2,6,10,3,7,11]
Shape: (3,4)
Strides: (1,3)

[[0, 1, 2, 3]
[4, 5, 6, 7]
[8, 9, 10, 11]]

Row-Major Layout (C-style)

Memory: [0,1,2,3,4,5,6,7,8,9,10,11]
Shape: (3,4)
Strides: (4,1)

[[0, 1, 2, 3]
[4, 5, 6, 7]
[8, 9, 10, 11]]

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tensor_views_4.svg b/docs/conceptual/ck_tile/diagrams/tensor_views_4.svg new file mode 100644 index 0000000000..8e521229cf --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tensor_views_4.svg @@ -0,0 +1 @@ +

Optimization Strategies

Memory Access Patterns

Sequential Access
(Good cache usage)

Strided Access
(May cause cache misses)

Random Access
(Poor cache usage)

Use row-major for row iteration

Use col-major for column iteration

Minimize stride between accesses

Vectorize when possible

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tensor_views_5.svg b/docs/conceptual/ck_tile/diagrams/tensor_views_5.svg new file mode 100644 index 0000000000..2faec8d8d3 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tensor_views_5.svg @@ -0,0 +1 @@ +

Use Cases

TensorView

BufferView

Linear indexing only

buffer[5]

No shape information

Direct memory access

Multi-dimensional indexing

tensor(1, 2)

Shape-aware operations

Coordinate transformations

BufferView: Low-level memory ops

TensorView: Matrix/tensor algorithms

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/thread_mapping_1.svg b/docs/conceptual/ck_tile/diagrams/thread_mapping_1.svg new file mode 100644 index 0000000000..119f631829 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/thread_mapping_1.svg @@ -0,0 +1 @@ +

P-space Mapping

Thread Identification

GPU Device

Thread Block

Warp 0

Warp 1

Thread 32
lane_id=0

Thread 33
lane_id=1

...

Thread 63
lane_id=31

Thread 0
lane_id=0

Thread 1
lane_id=1

...

Thread 31
lane_id=31

Warp 2

...

Warp 7

Thread ID = blockIdx.x * blockDim.x + threadIdx.x

Warp ID = threadIdx.x / 32

Lane ID = threadIdx.x % 32

P-coordinates
NDimP=1: [thread_id]
NDimP=2: [warp_id, lane_id]

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/thread_mapping_2.svg b/docs/conceptual/ck_tile/diagrams/thread_mapping_2.svg new file mode 100644 index 0000000000..f523de1c8c --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/thread_mapping_2.svg @@ -0,0 +1 @@ +

Thread to Data Mapping

Memory Access

Data Tiles

Thread Grid

Coalesced Access
Adjacent threads → Adjacent memory

Thread[0,0]
Warp 0

Data[0:4, 0:4]
16 elements

Thread[0,1]
Warp 0

Data[0:4, 4:8]
16 elements

Thread[1,0]
Warp 1

Data[4:8, 0:4]
16 elements

Thread[1,1]
Warp 1

Data[4:8, 4:8]
16 elements

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_distribution_1.svg b/docs/conceptual/ck_tile/diagrams/tile_distribution_1.svg new file mode 100644 index 0000000000..19e7140013 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_distribution_1.svg @@ -0,0 +1 @@ +

GPU Execution

Coordinate Spaces

Logical View

Tensor
Multi-dimensional data

TileDistribution
Work assignment

TileWindow
Data view

X: Physical tensor coords

Y: Tile pattern coords

P: Processing element coords

R: Replication coords (optional)

Warps
32 threads each

Lanes
Thread within warp

Registers
Thread-local storage

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_distribution_2.svg b/docs/conceptual/ck_tile/diagrams/tile_distribution_2.svg new file mode 100644 index 0000000000..6f588a46c4 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_distribution_2.svg @@ -0,0 +1 @@ +

Output

Transformation Pipeline

Input

Thread Coordinates
(warpId, laneId)

P → Y
Thread to pattern

Y → X
Pattern to physical

Y → D
Pattern to register

Memory Coordinates
Global addresses

Register Indices
Local storage

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_distribution_3.svg b/docs/conceptual/ck_tile/diagrams/tile_distribution_3.svg new file mode 100644 index 0000000000..0974e138fd --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_distribution_3.svg @@ -0,0 +1 @@ +

Memory Pattern

Thread Assignment

Problem Space (256×256 Matrix)

Full Matrix
65,536 elements

Tile 1
32×32

Tile 2
32×32

Tile N
32×32

Warp 0
32 threads

Warp 1
32 threads

Lane 0-31
Individual threads

Coalesced Access
Sequential addresses
No bank conflicts

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_distribution_4.svg b/docs/conceptual/ck_tile/diagrams/tile_distribution_4.svg new file mode 100644 index 0000000000..894151380d --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_distribution_4.svg @@ -0,0 +1 @@ +

Level 3: Thread Distribution

Level 2: Warp Distribution

Level 1: Block Distribution

Thread Block
256 threads

Block Tile 1
64×64

Block Tile 2
64×64

Warp
32 threads

Warp Tile 1
16×16

Warp Tile 2
16×16

Thread

Thread Tile
2×2

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_distribution_5.svg b/docs/conceptual/ck_tile/diagrams/tile_distribution_5.svg new file mode 100644 index 0000000000..2e46ee58cf --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_distribution_5.svg @@ -0,0 +1 @@ +

Memory Access

Per Thread

Thread Grid (32×32)

Matrix C (128×128)

16,384 elements

1,024 threads

4×4 tile
16 elements

Coalesced reads
Efficient writes
No conflicts

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_distribution_6.svg b/docs/conceptual/ck_tile/diagrams/tile_distribution_6.svg new file mode 100644 index 0000000000..2195465e60 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_distribution_6.svg @@ -0,0 +1 @@ +

Output

Stage 3

Stage 2

Stage 1

Input

Thread ID
(0-1023)

P-coordinates
(warp, lane)

Y-coordinates
(tile position)

X-coordinates
(tensor indices)

Memory addresses
Register indices

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_distribution_7.svg b/docs/conceptual/ck_tile/diagrams/tile_distribution_7.svg new file mode 100644 index 0000000000..e9ec5a5780 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_distribution_7.svg @@ -0,0 +1 @@ +

Performance

With TileDistribution

Manual Implementation

Calculate indices manually

Handle boundary conditions

Ensure coalescing

Manage bank conflicts

~200 lines of code

make_tile_distribution()

Automatic optimization

~10 lines of code

Same performance

Fewer bugs

Portable across GPUs

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_window_1.svg b/docs/conceptual/ck_tile/diagrams/tile_window_1.svg new file mode 100644 index 0000000000..6c2203c332 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_window_1.svg @@ -0,0 +1 @@ +

Optimizations

Operations

Components

TensorView
Data source

TileDistribution
Thread mapping

TileWindow
Access gateway

LoadStoreTraits
Access optimizer

DistributedTensor
Register storage

Load
Global → Registers

Compute
In registers

Store
Registers → Global

Coalescing
Adjacent access

Vectorization
Multi-element ops

Bank conflict
avoidance

Space-filling
curve traversal

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_window_2.svg b/docs/conceptual/ck_tile/diagrams/tile_window_2.svg new file mode 100644 index 0000000000..60ec2dd1ce --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_window_2.svg @@ -0,0 +1 @@ +

Snake Access Pattern

0,1,2,3

7,6,5,4

8,9,10,11

15,14,13,12

Linear Access Pattern

0,1,2,3

4,5,6,7

8,9,10,11

12,13,14,15

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_window_3.svg b/docs/conceptual/ck_tile/diagrams/tile_window_3.svg new file mode 100644 index 0000000000..9b2293d295 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_window_3.svg @@ -0,0 +1 @@ +

Step 3: Load Data

Step 2: Apply Distribution

Step 1: Create Window

load()

Tensor
[256, 256]

Origin
(64, 64)

Window Size
[32, 32]

TileDistribution
Thread mapping

TileWindow
Created

Global Memory
Window region

Registers
Distributed tensor

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_window_4.svg b/docs/conceptual/ck_tile/diagrams/tile_window_4.svg new file mode 100644 index 0000000000..f031c7778b --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_window_4.svg @@ -0,0 +1 @@ +

Result

Memory Transaction

Vectorization

Load Analysis

Analyze access pattern
Detect coalescing opportunities

Scalar: 4 loads

Vector2: 2 loads

Vector4: 1 load

Coalesced access
32 threads → 1 transaction

Non-coalesced
32 threads → 32 transactions

Thread registers
Local data

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/tile_window_5.svg b/docs/conceptual/ck_tile/diagrams/tile_window_5.svg new file mode 100644 index 0000000000..16ae1d01cc --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/tile_window_5.svg @@ -0,0 +1 @@ +

Hardware Utilization

Memory Access Optimization

Vectorization
4x fewer transactions

Coalescing
32x bandwidth efficiency

Precomputation
Zero overhead addressing

Space-filling
Optimal cache usage

Memory Bandwidth
Near 100% utilization

Latency Hiding
Overlapped operations

Register Reuse
Minimal spills

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_1.svg b/docs/conceptual/ck_tile/diagrams/transforms_1.svg new file mode 100644 index 0000000000..3f00bbee54 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_1.svg @@ -0,0 +1 @@ +

Tensor Coordinate Transformation

Forward Transform

Inverse Transform

Same data,
different views

Same data,
different views

Lower Dimension Space
Source coordinate system

Upper Dimension Space
Target coordinate system

Linear Data in Memory
Layout determined by tensor
shape & strides

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_10.svg b/docs/conceptual/ck_tile/diagrams/transforms_10.svg new file mode 100644 index 0000000000..34f7b7c04b --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_10.svg @@ -0,0 +1 @@ +

XorTransform: 2D → 2D XOR Mapping

Forward Transform
apply XOR reverse

Inverse Transform
apply XOR mapping

XOR pattern
view

Normal
view

Lower Coordinate Space
2D: [4, 8]
XOR-transformed coords

Upper Coordinate Space
2D: [4, 8]
Normal coords

Same Tensor Data

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_11.svg b/docs/conceptual/ck_tile/diagrams/transforms_11.svg new file mode 100644 index 0000000000..688dcab9ca --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_11.svg @@ -0,0 +1 @@ +

SliceTransform: 1D → 1D Sub-region

Forward Transform
idx + slice_begin

Inverse Transform
idx - slice_begin

Full tensor
view

Sub-region
view

Lower Coordinate Space
1D: [0, 9] (original range)

Upper Coordinate Space
1D: [0, 4] (slice range)

Tensor Data in Memory

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_12.svg b/docs/conceptual/ck_tile/diagrams/transforms_12.svg new file mode 100644 index 0000000000..f754ba4964 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_12.svg @@ -0,0 +1 @@ +

ModuloTransform: 1D → 1D Cyclic

Forward Transform
idx * cycle_count

Inverse Transform
idx % modulus

Lower Coordinate Space
1D: [0, 3] (modulus range)

Upper Coordinate Space
1D: [0, 15] (full range)

Tensor Data in Memory

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_2.svg b/docs/conceptual/ck_tile/diagrams/transforms_2.svg new file mode 100644 index 0000000000..26b40010bb --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_2.svg @@ -0,0 +1 @@ +

Operations

Transform Types

EmbedTransform
Linear → Multi-D Strided

MergeTransform
Multi-D → Linear

UnmergeTransform
Linear → Multi-D

ReplicateTransform
0D → Multi-D Broadcast

OffsetTransform
Translation

PassThroughTransform
Identity

PadTransform
Boundaries

Forward
calculate_lower_index()

Backward
calculate_upper_index()

Update
update_lower_index()

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_3.svg b/docs/conceptual/ck_tile/diagrams/transforms_3.svg new file mode 100644 index 0000000000..acd9de4a23 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_3.svg @@ -0,0 +1 @@ +

MergeTransform: Multi-D → Linear

Forward Transform
2×5 + 3 = 13

Inverse Transform
13÷5=2, 13%5=3

Multi-dimensional
view

Linear
view

Lower Coordinate Space
2D: [4, 5]
Coord: (2, 3)

Upper Coordinate Space
1D Linear
Index: 13

Same Tensor Data
Layout: row-major
Size: 20 elements

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_4.svg b/docs/conceptual/ck_tile/diagrams/transforms_4.svg new file mode 100644 index 0000000000..0bbf78430a --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_4.svg @@ -0,0 +1 @@ +

UnmergeTransform: Linear → Multi-D

Forward Transform
14 = 1×8 + 3×2 + 0

Inverse Transform
linearize back

Linear
view

Multi-dimensional
view

Lower Coordinate Space
1D Linear
Index: 14

Upper Coordinate Space
3D: [3, 4, 2]
Coord: (1, 3, 0)

Same Tensor Data
Layout: row-major
Size: 24 elements

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_5.svg b/docs/conceptual/ck_tile/diagrams/transforms_5.svg new file mode 100644 index 0000000000..3f57a2b675 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_5.svg @@ -0,0 +1 @@ +

EmbedTransform: Linear → Multi-D Strided

Forward Transform
Strides: [12, 1]
14 ÷ 12 = 1, 14 % 12 = 2

Inverse Transform
1×12 + 2×1 = 14

Linear
index view

Multi-dimensional
strided view

Lower Coordinate Space
1D Linear
Index: 14

Upper Coordinate Space
2D: [2, 3]
Coord: (1, 2)

Linear Buffer in Memory

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_6.svg b/docs/conceptual/ck_tile/diagrams/transforms_6.svg new file mode 100644 index 0000000000..014fa90176 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_6.svg @@ -0,0 +1 @@ +

ReplicateTransform: 0D → Multi-D Broadcasting

Forward Transform
[] → (i,j) for any i,j

Inverse Transform
(i,j) → [] for any i,j

One scalar
value

Broadcasted view
at all positions

Lower Coordinate Space
0D: Scalar
Empty coordinate []

Upper Coordinate Space
2D: [3, 4]
All coords: (i, j)

Single Scalar Value

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_7.svg b/docs/conceptual/ck_tile/diagrams/transforms_7.svg new file mode 100644 index 0000000000..676196744d --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_7.svg @@ -0,0 +1 @@ +

OffsetTransform: 1D → 1D Translation

Forward Transform
idx → idx + 16

Inverse Transform
idx + 16 → idx

Lower
view

Upper
view

Lower Coordinate Space
1D: [0, 63]
Coord: index + offset

Upper Coordinate Space
1D: [0, 47]
Coord: index

Linear Buffer in Memory

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_8.svg b/docs/conceptual/ck_tile/diagrams/transforms_8.svg new file mode 100644 index 0000000000..ddb41be8fe --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_8.svg @@ -0,0 +1 @@ +

PassThroughTransform: 1D → 1D Identity

Perfect Identity
idx → idx

Perfect Identity
idx → idx

Same buffer
same view

Same buffer
same view

Lower Coordinate Space
1D: [0, 59]
Coord: index

Upper Coordinate Space
1D: [0, 59]
Coord: index

Linear Buffer in Memory

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/diagrams/transforms_9.svg b/docs/conceptual/ck_tile/diagrams/transforms_9.svg new file mode 100644 index 0000000000..5b219099b0 --- /dev/null +++ b/docs/conceptual/ck_tile/diagrams/transforms_9.svg @@ -0,0 +1 @@ +

PadTransform: 1D → 1D with Padding

Forward Transform
idx + left_pad

Inverse Transform
idx - left_pad

Original view

Padded view

Lower Coordinate Space
1D: [0, 2] (original data)

Upper Coordinate Space
1D: [0, 4] (with padding)

Tensor Data in Memory

\ No newline at end of file diff --git a/docs/conceptual/ck_tile/encoding_internals.rst b/docs/conceptual/ck_tile/encoding_internals.rst new file mode 100644 index 0000000000..499ec0bd4a --- /dev/null +++ b/docs/conceptual/ck_tile/encoding_internals.rst @@ -0,0 +1,489 @@ +.. meta:: + :description: CK Tile encoding internals documentation + :keywords: CK Tile, encoding, tile distribution, GPU programming, compile-time computation + +.. _ck_tile_encoding_internals: + +****************** +Encoding Internals +****************** + +Overview +======== + +The tile distribution encoding system represents the core mathematical framework that transforms high-level tensor distribution specifications into concrete, optimized GPU kernel implementations. This advanced compile-time machinery bridges the gap between abstract mathematical descriptions and executable coordinate transformations, enabling the Composable Kernel framework to generate highly efficient code for complex tensor operations. + +At its heart, the encoding system defines how multi-dimensional tensor data is distributed across GPU processing elements through a hierarchical decomposition scheme. By specifying the relationships among the replication (R), hierarchical (H), partition (P), and yield (Y) coordinate spaces, the encoding provides a complete blueprint for data layout and access patterns that can be resolved entirely at compile time. This is the internal mechanism behind :ref:`ck_tile_tile_distribution`. See :ref:`ck_tile_coordinate_systems` for more information about coordinate spaces. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Encoding Components" + RS["R-space Lengths
Replication dimensions"] + HS["H-space Lengths
Hierarchical decomposition
[[2,2],[2,2]]"] + P2RH["P→RH Mappings
Thread to hierarchy
Major/Minor"] + Y2RH["Y→RH Mappings
Element to hierarchy
Major/Minor"] + end + + subgraph "Generated Components" + ADAPTOR["ps_ys_to_xs_adaptor
Coordinate transformer"] + DESC["ys_to_d_descriptor
Memory linearizer"] + ENC["Encoding
Original specification"] + end + + subgraph "Transformation Chain" + T1["Replicate
Transform"] + T2["Unmerge
Transform"] + T3["Merge
Transform"] + end + + RS --> T1 + HS --> T2 + P2RH --> ADAPTOR + Y2RH --> ADAPTOR + + T1 --> T2 + T2 --> T3 + T3 --> ADAPTOR + + HS --> DESC + Y2RH --> DESC + + style RS fill:#fce4ec,stroke:#c2185b,stroke-width:2px + style HS fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style ADAPTOR fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style DESC fill:#fff3e0,stroke:#f57c00,stroke-width:3px + + + +.. image:: diagrams/encoding_internals_1.svg + :alt: Diagram + :align: center + +Encoding Structure +================== + +The tile distribution encoding employs a template-based type system that captures the complete specification of tensor distribution patterns at compile time: + +.. code-block:: cpp + + template // Y to RH mapping (minor) + struct tile_distribution_encoding + { + // All computations resolved at compile time + static constexpr index_t NDimX = HsLengthss::size(); + static constexpr index_t NDimP = Ps2RHssMajor::size(); + static constexpr index_t NDimY = Ys2RHsMajor::size(); + static constexpr index_t NDimR = RsLengths::size(); + + // Static member functions for compile-time access + __host__ __device__ static constexpr auto get_rs_lengths() { + return RsLengths_{}; + } + + __host__ __device__ static constexpr auto get_hs_lengthss() { + return HsLengthss_{}; + } + + // Nested detail struct performs complex compile-time calculations + struct detail { + // Precomputed mappings and transformations + static constexpr auto get_h_dim_lengths_prefix_sum(); + static constexpr auto get_uniformed_idx_y_to_h(); + // ... compile-time computation ... + }; + }; + +Key Template Features +--------------------- + +1. **Template Metaprogramming**: All parameters are types, not values, enabling compile-time optimization +2. **Constexpr Functions**: Everything is computed at compile time +3. **Type Aliases**: Clean access to template parameters +4. 
**Static Member Functions**: No runtime overhead + +Parameter Breakdown +=================== + +R-Dimensions: Replication Specification +--------------------------------------- + +The ``RsLengths`` parameter defines dimensions that are replicated across processing units, enabling data sharing patterns essential for many tensor operations: + +.. code-block:: cpp + + // Example: GEMM with warp-level replication + using RsLengths = Sequence<NWarpPerBlock, MWarpPerBlock>; + + // This creates replication pattern: + // - NWarpPerBlock warps share the same A data + // - MWarpPerBlock warps share the same B data + +Replication serves several purposes: + +- **Data Reuse**: Same input data needed by multiple output computations +- **Reduction Operations**: Multiple threads collaborate on single result +- **Memory Efficiency**: Reduces global memory bandwidth requirements + +H-Dimensions: Hierarchical Decomposition +---------------------------------------- + +The ``HsLengthss`` parameter represents hierarchical decomposition of tensor dimensions: + +.. code-block:: cpp + + // Example: Block-level GEMM decomposition + using HsLengthss = Tuple< + Sequence<MRepeat, MWarp, MThread, MVec>, // M-dimension + Sequence<NRepeat, NWarp, NThread, NVec> // N-dimension + >; + + // This creates hierarchy: + // - MRepeat: iterations per thread in M + // - MWarp: warps assigned to M + // - MThread: threads per warp for M + // - MVec: vector size for M + +The decomposition enables: + +- **Memory Coalescing**: Aligning with warp/thread organization +- **Register Blocking**: Tile sizes that fit in register file +- **Shared Memory Utilization**: Tiles that exploit data reuse + +P-Dimensions: Partition Mapping +------------------------------- + +The ``Ps2RHssMajor`` and ``Ps2RHssMinor`` parameters define work assignment: + +.. 
code-block:: cpp + + // Example: 2D thread block mapping + // P0 = warp_id, P1 = lane_id + using Ps2RHssMajor = Tuple< + Sequence<1>, // P0 maps to H1 (warp dimension) + Sequence<2> // P1 maps to H2 (thread dimension) + >; + using Ps2RHssMinor = Tuple< + Sequence<1>, // Use second component of H1 + Sequence<2> // Use third component of H2 + >; + +The mapping mechanism: + +- **Major Index**: Which RH-dimension group (0=R, 1-N=H) +- **Minor Index**: Component within that group + +Y-Dimensions: Logical View Mapping +---------------------------------- + +The ``Ys2RHsMajor`` and ``Ys2RHsMinor`` define the user-facing interface: + +.. code-block:: cpp + + // Example: 2D tile access pattern + using Ys2RHsMajor = Sequence<1, 1, 2, 2>; // Y→H mapping + using Ys2RHsMinor = Sequence<0, 1, 0, 1>; // Component selection + + // Creates 2x2 logical view: + // Y[0,0] → H1[0], H2[0] + // Y[0,1] → H1[1], H2[0] + // Y[1,0] → H1[0], H2[1] + // Y[1,1] → H1[1], H2[1] + +Transformation Pipeline +======================= + +The encoding generates a transformation pipeline that converts coordinates using the concepts from :ref:`ck_tile_transforms` and :ref:`ck_tile_adaptors`: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + flowchart LR + subgraph "Input Coordinates" + P["P-coordinates
[warp_id, lane_id]"] + Y["Y-coordinates
[y0, y1, y2, y3]"] + end + + subgraph "Transformation Pipeline" + C1["Combine P+Y"] + T1["Replicate
Transform
(if R-dims exist)"] + T2["Unmerge
Transform
(break into H-dims)"] + T3["Merge
Transform
(combine to X-dims)"] + end + + subgraph "Output" + X["X-coordinates
[x0, x1]
Tensor position"] + end + + P --> C1 + Y --> C1 + C1 --> T1 + T1 --> T2 + T2 --> T3 + T3 --> X + + style P fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style Y fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style X fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + +.. image:: diagrams/encoding_internals_2.svg + :alt: Diagram + :align: center + +Building the Transformation Chain +--------------------------------- + +.. code-block:: cpp + + template + __host__ __device__ auto make_ps_ys_to_xs_adaptor(const Encoding& encoding) + { + // Step 1: Create individual transforms + constexpr auto replicate_transform = make_replicate_transform( + encoding.get_rs_lengths()); + + constexpr auto unmerge_transform = make_unmerge_transform( + encoding.get_hs_lengthss()); + + constexpr auto merge_transform = make_merge_transform( + encoding.get_rhs_to_xs_mapping()); + + // Step 2: Chain transforms together + constexpr auto transform_chain = chain_transforms( + replicate_transform, + unmerge_transform, + merge_transform); + + // Step 3: Create adaptor with the chain + return make_tile_adaptor( + transform_chain, + encoding.get_lower_dimension_hidden_idss()); + } + +Transform Implementation Example +-------------------------------- + +.. 
code-block:: cpp + + // Replicate transform implementation + template + struct replicate_transform + { + static constexpr index_t num_of_upper_dimension = size(Lengths{}); + static constexpr index_t num_of_lower_dimension = 2 * num_of_upper_dimension; + + template + __host__ __device__ constexpr auto + calculate_lower_index(const UpperIndex& idx_upper) const + { + // Replicate each coordinate: [a,b] -> [a,b,0,0] + auto idx_lower = make_zero_multi_index(); + + static_for<0, num_of_upper_dimension, 1>{}([&](auto i) { + idx_lower(i) = idx_upper[i]; + idx_lower(i + num_of_upper_dimension) = 0; + }); + + return idx_lower; + } + }; + +Y to D Linearization +==================== + +The Y→D descriptor handles memory layout within each thread, building on :ref:`ck_tile_descriptors` concepts: + +.. code-block:: cpp + + template + struct ys_to_d_descriptor + { + static constexpr index_t num_of_dimension = size(YLengths{}); + + // Calculate linear offset from Y coordinates + template + __host__ __device__ constexpr index_t + calculate_offset(const YIndex& idx_y) const + { + index_t offset = 0; + + static_for<0, num_of_dimension, 1>{}([&](auto i) { + offset += idx_y[i] * YStrides{}[i]; + }); + + return offset; + } + + // Get element space size (total elements per thread) + __host__ __device__ static constexpr index_t + get_element_space_size() + { + return reduce_on_sequence( + YLengths{}, + multiplies{}, + number<1>{}); + } + }; + +Memory Layout Optimization +-------------------------- + +.. code-block:: cpp + + // Optimized layout for vector operations + template + struct make_ys_to_d_descriptor_for_gemm + { + // Layout: [M/VectorSize][N][VectorSize] + // This ensures vector loads are contiguous in memory + using type = tile_descriptor< + Sequence, + Sequence>; + }; + +Integration in Distributed Tensor +--------------------------------- + +This shows how the encoding integrates with :ref:`ck_tile_static_distributed_tensor`: + +.. 
code-block:: cpp + + template + struct static_distributed_tensor + { + using ys_to_d_descriptor = typename TileDistribution::ys_to_d_descriptor; + + // Thread-local storage + static constexpr index_t thread_buffer_size = + ys_to_d_descriptor::get_element_space_size(); + + DataType thread_buffer_[thread_buffer_size]; + + // Access element at Y coordinate + template + __host__ __device__ DataType& at(const YIndex& idx_y) + { + const index_t offset = ys_to_d_descriptor{}.calculate_offset(idx_y); + return thread_buffer_[offset]; + } + }; + +Practical Examples +================== + +Example 1: Simple 2x2 Distribution +---------------------------------- + +.. code-block:: cpp + + // No replication, simple hierarchy + using SimpleEncoding = tile_distribution_encoding< + Sequence<>, // rs_lengths: no replication + Tuple< // hs_lengthss: 2x2 hierarchy + Sequence<2>, + Sequence<2> + >, + Tuple, Sequence<>>, // ps_to_rhss_major + Tuple, Sequence<>>, // ps_to_rhss_minor + Sequence<1, 2>, // ys_to_rhs_major + Sequence<0, 0> // ys_to_rhs_minor + >; + +Example 2: GEMM Distribution +---------------------------- + +.. code-block:: cpp + + // Complex GEMM distribution with replication + template + using GemmBlockEncoding = tile_distribution_encoding< + Sequence<>, // No block-level replication + Tuple< // Hierarchical decomposition + Sequence, // M + Sequence // N + >, + Tuple< // Warp assignment + Sequence<1, 2>, // [warp_m, warp_n] + Sequence<> + >, + Tuple< + Sequence<1, 0>, // Major indices + Sequence<> + >, + Sequence<1, 1, 2, 2>, // Y mapping + Sequence<0, 1, 0, 1> // Y components + >; + +Performance Implications +======================== + +The encoding system is designed for maximum GPU performance. See :ref:`ck_tile_gpu_basics` for hardware fundamentals. 
+ +Memory Access Patterns +---------------------- + +- **Coalescing**: Hierarchical decomposition ensures adjacent threads access adjacent memory +- **Bank Conflicts**: Careful dimension ordering prevents shared memory conflicts. See :ref:`ck_tile_lds_bank_conflicts` for more information. +- **Vectorization**: Natural support for vector loads and stores. See :ref:`ck_tile_load_store_traits` for more information. + +Register Efficiency +------------------- + +- **Optimal Allocation**: Y→D linearization minimizes register usage +- **Spill Avoidance**: Compile-time sizing prevents register spills +- **Reuse Patterns**: Encoding enables efficient register reuse + +Compile-Time Optimization +------------------------- + +.. code-block:: cpp + + // All encoding operations resolve at compile time + template + struct encoding_optimizer { + // Compute all derived values at compile time + static constexpr auto total_elements = /* computed */; + static constexpr auto access_pattern = /* computed */; + static constexpr auto memory_layout = /* computed */; + + // Generate optimized code paths + template + __device__ void apply_optimized(Func&& f) { + if constexpr (is_simple_pattern) { + // Direct access path + } else if constexpr (is_strided_pattern) { + // Strided access path + } else { + // General access path + } + } + }; + +Summary +======= + +The tile distribution encoding system demonstrates compile-time computation: + +- **Mathematical Foundation**: Complete specification through dimensional relationships +- **Zero Overhead**: All computations resolve at compile time +- **Composable Design**: Individual transforms compose into complex mappings +- **Hardware Alignment**: Natural mapping to GPU execution hierarchy +- **Performance Focus**: Every design decision optimizes for GPU efficiency + +The encoding internals show how CK Tile achieves practical performance. 
By leveraging C++ template metaprogramming and careful architectural design, the framework generates code that rivals hand-optimized implementations while maintaining clarity and composability. + +For practical examples of how the encoding system is used, see :ref:`ck_tile_thread_mapping`. For coordinate operations that build on these encodings, see :ref:`ck_tile_coordinate_movement`. diff --git a/docs/conceptual/ck_tile/hardware/gemm_optimization.rst b/docs/conceptual/ck_tile/hardware/gemm_optimization.rst new file mode 100644 index 0000000000..a31b6b7803 --- /dev/null +++ b/docs/conceptual/ck_tile/hardware/gemm_optimization.rst @@ -0,0 +1,385 @@ +.. meta:: + :description: Block GEMM optimization on MI300 using CK Tile + :keywords: GEMM, matrix multiplication, MI300, CK, Composable Kernel, GPU optimization + +.. _ck_tile_gemm_optimization: + +******************************************************************** +A Block GEMM on MI300 +******************************************************************** + +Introduction to GEMMs +===================== + +This document illustrates key concepts of implementing a block GEMM (General Matrix Multiplication) kernel on AMD's MI300 GPU. GEMM is a fundamental building block for many machine learning workloads, including attention mechanisms and Mixture of Experts (MoE) models. + +The problem addressed here is the standard matrix multiplication: :math:`C = A \cdot B`, where matrix A has dimensions **M x K** and matrix B has dimensions **K x N**. The resulting matrix C will have dimensions **M x N**. For simplicity and a better memory access pattern, it will be assumed that matrix B is in a column-major format, which means its shape is logically represented as **N x K**. + +Format and Dimensions +===================== + +The first step in designing the kernel is to select the data format and dimensions. 
+ +Data Format: bf16 +----------------- + +While ``float32`` is a common choice, its high precision is computationally expensive and can be unnecessary for model convergence. A more suitable alternative is a half-precision floating-point format. We will use **bfloat16 (bf16)**. + +Bfloat16 is a 16-bit format that uses the same 8-bit exponent as ``float32``. This allows it to have the same dynamic range, which is critical for avoiding overflow and underflow during training. The key difference is that ``bf16`` uses only 7 bits for the mantissa (versus 23 bits in ``float32``), which makes it functionally equivalent to a simple right bit-shift of a 32-bit float: ``(float32 >> 16)``. + +Dimensions: M=4864, N=4096 +-------------------------- + +To maximize hardware utilization, dimensions are used that utilize the GPU's resources well. For this example, **M = 4864** and **N = 4096** are used. The rationale behind these particular values will be explained later. + +Input data +---------- + +The input will be uniformly distributed random data on the interval [-1, 1]: + +.. code-block:: cpp + + initializeMatrix(A.data(), M, K, -1.0, 1.0); + initializeMatrix(B.data(), N, K, -1.0, 1.0); + +Simple Matmul +============= + +On the AMD **MI300** GPU series (see :ref:`ck_tile_gpu_basics`), each Compute Unit (CU) contains **four SIMD units**. Each SIMD unit can execute a single **wavefront** of 64 threads in parallel. Since there are four wavefronts per CU, a CU can therefore sustain the execution of up to **256 concurrent threads**. + +These 256 threads then can be logically grouped into a **thread block**, which is responsible for computing a **sub-block (tile)** of the output matrix ``C``. A block of 256 threads can be arranged as a **16×16 thread block**, where each thread computes one element of a **16×16 tile** of the result matrix ``C``. Multiple thread blocks are then organized into a **grid**, such that the collection of blocks covers the entire output matrix. 
+ +Consider a baseline matrix multiplication kernel where **each thread computes one output element** of ``C``. The HIP launch configuration can be defined as: + +.. code-block:: cpp + + dim3 blockSizeRef(16, 16); + dim3 gridSizeRef((N + blockSizeRef.x - 1) / blockSizeRef.x, + (M + blockSizeRef.y - 1) / blockSizeRef.y); + + matrixMulHIP<<<gridSizeRef, blockSizeRef>>>(d_A, d_B, d_C); + +And the GPU Kernel: + +.. code-block:: cpp + + __global__ void matrixMulHIP(s_type * __restrict__ A, + s_type* __restrict__ B, + float* __restrict__ C) + { + // Calculate global thread coordinates in output matrix C + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Boundary check for valid threads + if (row < M && col < N) { + float value = 0.0f; + // Perform the dot product of row from A and column from B + for (int k = 0; k < K; ++k) { + value += A[row * K + k] * B[col * K + k]; + } + // Store computed value in output matrix + C[row * N + col] = value; + } + } + +This kernel has a very low compute throughput according to ``rocprofv3`` profiler output. It is stalling on global memory read transactions, effectively starving the rest of the pipeline that needs that data to proceed. + +Memory Bandwidth Analysis +------------------------- + +In a naïve implementation of matrix multiplication, **pressure on global memory loads** quickly becomes the bottleneck. To understand why, it is necessary to look at how a single **16×16 block** of the destination matrix ``C`` is computed by one block of threads within a compute unit. + +Each thread in the block is responsible for computing a single element of ``C``. To do so, it loops over the ``K`` dimension and, in every iteration, fetches **two values** from global memory: + +- one from a row of ``A`` +- one from a column of ``B`` + +This means: + +- Number of threads in a 16×16 block is 256. 
+- Each thread performs 2K global loads +- **Total global loads** = 256 × 2K = 512K +- **Total global stores** = 256 (one per output element in ``C``) + +To reuse each element of ``A`` and ``B`` perfectly (loading each only once), the unique data required would be: + +- Unique ``A`` elements: 16 × K = 16K +- Unique ``B`` elements: 16 × K = 16K +- **Total unique loads** = 16K + 16K = 32K +- **Total stores** = 256 + +- **Naïve kernel**: 512K global loads + 256 stores +- **Ideal reuse**: 32K global loads + 256 stores + +This illustrates a **16× difference in memory traffic** for the same computation on a small, 16x16 block. + +What is Tiling? +=============== + +Cooperative Loading with LDS +---------------------------- + +In the naïve implementation, threads within the same compute unit (CU) do not cooperate with each other at all. Each thread independently and greedily loads the row elements of ``A`` and the column elements of ``B`` that it needs in order to compute its corresponding value in ``C``. + +Each CU on the MI300 has **64 KB of Local Data Share (LDS)** (see :ref:`ck_tile_lds_bank_conflicts` for optimization techniques) that acts as a shared memory space accessible by all threads in that CU. This opens the possibility of **cooperative loading**. + +Instead of having every thread repeatedly fetch its own data directly from global memory, threads can **collaboratively preload** a block of data into LDS. Once in LDS, this data can be reused by many threads, reducing redundant global memory fetches. + +Entire rows or columns of ``A`` and ``B`` can't be preloaded into LDS, since they might be very large and LDS has a fixed capacity. The solution is to load **small blocks (tiles)** of data at a time. 
For example: + +- Load a **16×16 tile** from ``A`` and ``B`` into LDS +- Allow all threads in the CU to reuse the data from that tile to compute their portion of the result +- Once done, move the tile window forward along the ``K`` dimension +- Repeat until the entire **16×16 output block** of ``C`` is computed + +This technique of **tiling with cooperative loading** reduces global memory traffic and improves GPU efficiency, because the fast, on-chip LDS offers much lower access latency than global memory and lets many threads reuse the same data. + +Tiling Mathematics +------------------ + +How many elements of matrices A and B need to be loaded with the tiling approach? + +For a thread block computing a ``TILE_M × TILE_N`` output tile with K-blocking: + +- Elements of **A** loaded per block: + + .. math:: + \text{A\_loads} = \mathrm{TILE\_M} \cdot K + +- Elements of **B** loaded per block: + + .. math:: + \text{B\_loads} = \mathrm{TILE\_N} \cdot K + +- Total outputs produced per block: + + .. math:: + \text{outputs} = \mathrm{TILE\_M} \cdot \mathrm{TILE\_N} + +The **average loads per output element** (ignoring C traffic) are: + +.. math:: + \text{loads per output} = \frac{\mathrm{TILE\_M}\cdot K + \mathrm{TILE\_N}\cdot K}{\mathrm{TILE\_M} \cdot \mathrm{TILE\_N}} = K \left(\frac{1}{\mathrm{TILE\_M}} + \frac{1}{\mathrm{TILE\_N}}\right) + +To simplify the formula, consider a square tile of size T, to compute one value in C: + +- Naïve (no tiling) = 2K loads per output. +- With tiling = 2K/T. +- **Reduction factor = T**. + +Example: T=16 + +.. math:: + \text{loads per output} = \frac{2K}{16} = \frac{K}{8} + +Compared to the naïve 2K, this gives a **16× reduction** in global memory traffic per output element. + +LDS Usage and Tiling Efficiency +------------------------------- + +How much space in LDS would this tiling use? Matrices **A** and **B** store data in **bf16** format. For a small 16×16 tile: + +- Each matrix contains 16 × 16 = 256 elements. 
+- At 2 bytes per element, each matrix occupies 256 × 2 = 512 bytes. +- Total for A and B: 512 × 2 = 1 KB. + +There is much more space in LDS, so why not try a bigger tile size? 32 KB for each matrix can be used, which allows the tile size to be increased to **256×64**. With this tile size, each compute unit (CU) will output a **256×256 block in C**. With this approach, the number of global memory reads will be **256 times smaller per element in C** compared to a brute-force approach. + +Variation of the GEMM in Inference +---------------------------------- + +When implementing GEMM for inference, the B matrix holds the weights, which are static, so it can be preshuffled into the warp GEMM MFMA shape, giving faster loads into registers for the MFMA operations. This strategy enables the following optimizations: + +- Shared Memory bypass of the B Matrix. +- Loop over the A Matrix stored in the shared memory while B stays in the registers. +- Ping Pong buffering for the GEMM Pipeline + + +Utilization Considerations +-------------------------- + +This section explains why the input dimensions **M = 4864** and **N = 4096** are convenient choices. + +The MI300 has **304 compute units (CUs)**. If a tile size of **256×64** is chosen, where the **K dimension** is iterated over, then the output grid size is: + +.. code-block:: text + + M / 256 × N / 256 = 4864 / 256 × 4096 / 256 = 19 × 16 = 304 + +This matches the total number of compute units on the GPU. That means every CU can be fully occupied with one tile of work, and imbalance or underutilization is not as much of a concern. + +Advanced Optimizations +====================== + +Matrix Fused Multiply-Add +------------------------- + +Because compute-to-memory-access ratio can be a bottleneck, optimizing for bandwidth only isn't enough. + +GPUs offer dedicated **matrix (or tensor) cores** for multiplication tasks. These cores are specifically designed to accelerate matrix operations. 
+ +To take full advantage of these specialized cores, intrinsic instructions can be used. Intrinsic instructions are hardware-specific functions that allow for direct access to the matrix core pipelines. For this example, ``__builtin_amdgcn_mfma_f32_16x16x16f16``, which has a low latency of only 16 cycles, will be used. + +16x16 matrices will be used as input, and 16x16 matrices will be used as output. These instructions work as *accumulate add*; what they effectively do is: ``D = A*B + C``. This is useful in this example since results will be accumulated across multiple tiles along the K dimension. + +Optimizing Data Flow with Pipelining +------------------------------------ + +To maximize performance, the flow for this kernel uses a **pipeline** or **double buffering** to keep the compute units continuously fed with data, reducing idle time. This pipeline consists of a series of stages that process data concurrently: + +* **Stage 1: Global Memory to Registers:** The first stage involves pre-loading data directly from **global memory** into Vector General Purpose Registers (VGPR). This is the slowest part of the pipeline. Because of this, this operation is performed as early as possible. + +* **Stage 2: Registers to LDS (Shared Memory):** As data is being loaded from global memory, the next stage of the pipeline moves the data from the VGPRs into **LDS (Local Data Share)**, or shared memory. This is an intermediate step that makes the data accessible to all threads within the workgroup at very low latency. + +* **Stage 3: LDS to Registers:** With the data now in LDS, the data is transferred from LDS back into a different set of VGPRs, which will serve as the direct input for the compute operations. + +* **Stage 4: Computation with MFMA:** The Matrix-FMA (MFMA) intrinsic uses the data from the VGPRs to perform the actual matrix multiplication and accumulation. + +By using this pipelined approach, the different stages of data movement and computation happen in parallel. 
While the current VGPRs are being consumed by the MFMA operation, the next set of data is already being moved from LDS to another set of VGPRs, and the next tile of data is being loaded from global memory into a third set of VGPRs. This overlapping of operations is key to keeping the GPU's compute units fully utilized. + +CK Tile Implementation +====================== + +Here's how CK Tile implements an optimized GEMM kernel: + +.. code-block:: cpp + + template + __global__ void ck_tile_gemm_kernel(const ADataType* __restrict__ a_global, + const BDataType* __restrict__ b_global, + CDataType* __restrict__ c_global, + index_t M, + index_t N, + index_t K) + { + // Define tile distribution encoding + // See :ref:`ck_tile_encoding_internals` and :ref:`ck_tile_tile_distribution` + using Encoding = tile_distribution_encoding< + sequence<>, // No replication + tuple, // M dimension hierarchy + sequence<4, 2, 8, 4>>, // N dimension hierarchy + tuple, sequence<1, 2>>, // Thread mapping + tuple, sequence<2, 2>>, // Minor indices + sequence<1, 1, 2, 2>, // Y-space mapping + sequence<0, 3, 0, 3> // Y-space minor + >; + + constexpr auto tile_dist = make_static_tile_distribution(Encoding{}); + + // Create tensor views for global memory + // See :ref:`ck_tile_tensor_views` and :ref:`ck_tile_buffer_views` + auto a_global_view = make_naive_tensor_view( + a_global, make_tuple(M, K), make_tuple(K, 1)); + auto b_global_view = make_naive_tensor_view( + b_global, make_tuple(N, K), make_tuple(K, 1)); + auto c_global_view = make_naive_tensor_view( + c_global, make_tuple(M, N), make_tuple(N, 1)); + + // Calculate block offset + const index_t block_m_id = blockIdx.y; + const index_t block_n_id = blockIdx.x; + + // Create tile windows for loading + // See :ref:`ck_tile_tile_window` for tile window details + auto a_window = make_tile_window( + a_global_view, + make_tuple(number{}, number{}), + make_tuple(block_m_id * MPerBlock, 0), + tile_dist); + + auto b_window = make_tile_window( + 
b_global_view, + make_tuple(number{}, number{}), + make_tuple(block_n_id * NPerBlock, 0), + tile_dist); + + // Allocate LDS storage + // See :ref:`ck_tile_static_distributed_tensor` for distributed tensors + auto a_lds = make_static_distributed_tensor(); + auto b_lds = make_static_distributed_tensor(); + + // Initialize accumulator + auto c_reg = make_static_distributed_tensor(); + // See :ref:`ck_tile_sweep_tile` for sweep operations + sweep_tile(c_reg, [](auto idx, auto& val) { val = 0; }); + + // Main GEMM loop with pipelining + constexpr index_t num_k_tiles = K / KPerBlock; + + // Preload first tile + a_window.load(a_lds); + b_window.load(b_lds); + __syncthreads(); + + // Pipeline loop + for(index_t k_tile = 0; k_tile < num_k_tiles - 1; ++k_tile) { + // Move windows for next iteration + // See :ref:`ck_tile_coordinate_movement` for window movement + a_window.move_slice_window(make_tuple(0, KPerBlock)); + b_window.move_slice_window(make_tuple(0, KPerBlock)); + + // Prefetch next tile while computing current + auto a_lds_next = make_static_distributed_tensor(); + auto b_lds_next = make_static_distributed_tensor(); + + a_window.load_async(a_lds_next); + b_window.load_async(b_lds_next); + + // Compute with current tile + gemm_tile(a_lds, b_lds, c_reg); + + // Wait for prefetch and swap buffers + __syncthreads(); + a_lds = a_lds_next; + b_lds = b_lds_next; + } + + // Last tile computation + gemm_tile(a_lds, b_lds, c_reg); + + // Store result + auto c_window = make_tile_window( + c_global_view, + make_tuple(number{}, number{}), + make_tuple(block_m_id * MPerBlock, block_n_id * NPerBlock), + tile_dist); + + c_window.store(c_reg); + } + + +Key Takeaways +============= + +1. **Tiling is essential**: Reduces memory traffic by orders of magnitude +2. **Use specialized hardware**: MFMA instructions provide massive speedup +3. **Pipeline operations**: Hide memory latency with computation +4. **CK Tile abstractions**: Automatically handle complex optimizations +5. 
**Hardware-aware dimensions**: Choose problem sizes that map well to CU count + +By understanding these optimization techniques and using CK Tile's high-level abstractions, developers can improve performance on GPUs without manual low-level optimization. + +Related Topics + +- :ref:`ck_tile_tile_distribution` - Core distribution mechanism used in GEMM +- :ref:`ck_tile_tile_window` - Window-based data access patterns +- :ref:`ck_tile_static_distributed_tensor` - LDS memory management for tiles +- :ref:`ck_tile_lds_bank_conflicts` - Avoiding bank conflicts in GEMM +- :ref:`ck_tile_thread_mapping` - How threads map to GEMM computation +- :ref:`ck_tile_load_store_traits` - Optimized memory access patterns +- :ref:`ck_tile_space_filling_curve` - Advanced traversal patterns +- :ref:`ck_tile_sweep_tile` - Iterating over distributed data +- :ref:`ck_tile_gpu_basics` - Understanding the hardware +- :ref:`ck_tile_coordinate_systems` - Mathematical foundation diff --git a/docs/conceptual/ck_tile/hardware/gpu_basics.rst b/docs/conceptual/ck_tile/hardware/gpu_basics.rst new file mode 100644 index 0000000000..c8109c8200 --- /dev/null +++ b/docs/conceptual/ck_tile/hardware/gpu_basics.rst @@ -0,0 +1,38 @@ +.. meta:: + :description: Introduction to AMD CDNA Architecture for CK developers + :keywords: CDNA, RDNA, ROCm, CK, Composable Kernel, GPU architecture, compute units + +.. _ck_tile_gpu_basics: + +******************************************************************** +Intro to AMD CDNA Architecture +******************************************************************** + +The AMD CDNA architecture is a specialized GPU design for high-performance computing (HPC) and AI workloads. Unlike the RDNA architecture used in gaming GPUs, CDNA is optimized for data center tasks, prioritizing compute density, memory bandwidth, and scalability. This is achieved through several key architectural features. 
+ +For more information about the AMD GPU architecture, see the `GPU architecture documentation `_. + +Implications for CK Tile +======================== + +Understanding the CDNA architecture is crucial for effective use of CK Tile: + +1. **Thread Organization**: CK Tile's hierarchical :ref:`ck_tile_thread_mapping` (blocks → warps → threads) directly maps to CDNA's hardware organization. + +2. **Memory Hierarchy**: CK Tile's :ref:`ck_tile_buffer_views` and :ref:`ck_tile_tile_window` are designed to efficiently utilize the L2, Infinity Cache, and LDS hierarchy. + +3. **Register Pressure**: CK Tile's compile-time optimizations help minimize VGPR usage, preventing spills to slower memory. + +4. **Warp Execution**: CK Tile's :ref:`ck_tile_tile_distribution` ensures that threads within a warp access contiguous memory for optimal SIMD execution. + +5. **LDS Utilization**: CK Tile's :ref:`ck_tile_static_distributed_tensor` and :ref:`ck_tile_tile_window` make effective use of the 64KB LDS per CU. + +By understanding these architectural features, developers can better appreciate how CK Tile's abstractions map to hardware capabilities and why certain design decisions were made in the framework. + +Related Topics + +- :ref:`ck_tile_thread_mapping` - How threads are organized and mapped to hardware +- :ref:`ck_tile_coordinate_systems` - Mathematical foundation for data distribution +- :ref:`ck_tile_lds_bank_conflicts` - Optimizing shared memory access patterns +- :ref:`ck_tile_load_store_traits` - Memory access optimization strategies +- :ref:`ck_tile_gemm_optimization` - Practical application of architecture knowledge diff --git a/docs/conceptual/ck_tile/hardware/index.rst b/docs/conceptual/ck_tile/hardware/index.rst new file mode 100644 index 0000000000..d9191c7298 --- /dev/null +++ b/docs/conceptual/ck_tile/hardware/index.rst @@ -0,0 +1,127 @@ +.. 
meta:: + :description: CK Tile Hardware-Specific Documentation + :keywords: CDNA, GPU architecture, LDS, GEMM, CK, Composable Kernel + +.. _ck_tile_hardware: + +******************************************************************** +CK Tile Hardware Documentation +******************************************************************** + +This section provides in-depth coverage of hardware-specific concepts and optimizations for CK Tile on AMD GPUs. + +Overview +======== + +Understanding the underlying hardware architecture is crucial for achieving optimal performance with CK Tile. This documentation covers: + +- AMD CDNA architecture fundamentals +- Memory hierarchy and optimization techniques +- Practical examples of high-performance kernels + +Documentation Structure +======================= + +.. toctree:: + :maxdepth: 2 + :caption: Hardware Topics + + gpu_basics + lds_bank_conflicts + gemm_optimization + +GPU Architecture Basics +----------------------- + +:ref:`ck_tile_gpu_basics` provides an introduction to AMD CDNA architecture. + +LDS and Bank Conflicts +---------------------- + +:ref:`ck_tile_lds_bank_conflicts` explains Local Data Share (LDS) optimization. + +GEMM Optimization Case Study +---------------------------- + +:ref:`ck_tile_gemm_optimization` demonstrates a complete optimization journey. + + +Key Hardware Considerations +=========================== + + +Memory Hierarchy +---------------- + +1. **Global Memory**: High latency, high bandwidth + + - Optimize with coalesced access patterns + - Use tile windows for automatic optimization + +2. **L2/Infinity Cache**: Intermediate storage + + - Benefits from spatial and temporal locality + - CK Tile's tiling naturally improves cache hit rates + +3. **LDS**: Low latency, shared within CU + + - 64KB per CU, organized in 32 banks + - CK Tile handles bank conflict avoidance + +4. 
**Registers**: Lowest latency, per-thread storage + + - 512 VGPRs available per wavefront + - CK Tile's compile-time optimization minimizes usage + +Compute Resources +----------------- + +1. **Wavefront Execution**: 64 threads in lockstep + + - CK Tile ensures coalesced memory access + - Automatic warp-level synchronization + +2. **Matrix Units**: Specialized MFMA instructions + + - 16x16x16 operations in 16 cycles + - CK Tile can leverage these automatically + +3. **Occupancy**: Balancing threads vs resources + + - Register pressure affects occupancy + - CK Tile helps through efficient register use + +Performance Guidelines +====================== + +To achieve optimal performance with CK Tile: + +1. **Choose appropriate tile sizes**: + + - Match hardware capabilities (e.g., 256x256 for GEMM) + - Consider LDS capacity and register pressure + +2. **Align problem dimensions**: + + - Match CU count when possible (304 for MI300) + - Use padding for non-aligned sizes + +3. **Enable pipelining**: + + - Use double buffering for latency hiding + - CK Tile supports async operations + +4. **Profile and verify**: + + - Use rocprof to check for bottlenecks + - Verify bank conflict avoidance + - Monitor occupancy and register usage + +Next Steps +========== + +- Review :ref:`ck_tile_gpu_basics` for architecture fundamentals +- Study :ref:`ck_tile_lds_bank_conflicts` for shared memory optimization +- Explore :ref:`ck_tile_gemm_optimization` for a complete optimization example + +For practical implementation, refer back to the main :ref:`ck_tile_conceptual` documentation to see how these hardware concepts integrate with CK Tile's abstractions. diff --git a/docs/conceptual/ck_tile/hardware/lds_bank_conflicts.rst b/docs/conceptual/ck_tile/hardware/lds_bank_conflicts.rst new file mode 100644 index 0000000000..8802fba9e8 --- /dev/null +++ b/docs/conceptual/ck_tile/hardware/lds_bank_conflicts.rst @@ -0,0 +1,209 @@ +.. 
meta:: + :description: Understanding AMD GPU LDS and Bank Conflicts in CK Tile + :keywords: LDS, bank conflicts, shared memory, CK, Composable Kernel, GPU optimization + +.. _ck_tile_lds_bank_conflicts: + +******************************************************************** +Understanding AMD GPU LDS and Bank Conflicts +******************************************************************** + +Introduction +============ + +Local Data Share (**LDS**) is AMD's shared memory within a compute unit (see :ref:`ck_tile_gpu_basics` for architecture details). It is organized into **32 or 64 banks** depending on the hardware architecture, and each bank is 4 bytes wide. Understanding how memory addresses map to banks is key to avoiding **bank conflicts**. + +Bank Mapping +============ + +For AMD GCN architecture, the LDS bank mapping is typically: + +.. math:: + + \text{bank} = \left\lfloor \frac{\text{address in bytes}}{4} \right\rfloor \bmod 32 + +This means: + +- Addresses that differ by multiples of ``bank numbers * 4 bytes`` map to the same bank. +- Conflicts occur when multiple threads in the same wave access the same bank **in the same cycle**. + +Not all lanes can conflict with one another. The hardware divides a wavefront's access to LDS into phases. Which lanes are considered in each phase depends on the width of the instruction. Consider ``ds_write_b128`` as an example, since it is the write instruction with the largest granularity and the highest performance. Here, the access is divided into 8 phases for a 64-lane wavefront. If no two threads within a phase access the same bank, there will not be a bank conflict: + +- lane0~lane7 +- lane8~lane15 +- lane16~lane23 +- lane24~lane31 +- lane32~lane39 +- lane40~lane47 +- lane48~lane55 +- lane56~lane63 + +If within each group of lanes there is no conflict, it is an LDS bank conflict free write access. + +Bank Access Patterns +==================== + +LDS bank access can be simulated for a given set of thread addresses. 
With a 32 bank LDS with 4 bytes per bank, each thread will be writing 8 2-byte elements (16 bytes total), consuming 4 banks in LDS. fp16 or bf16 are the common formats GPU kernels have to deal with. With the phase access pattern described above, this is by default a bank conflict free LDS write access. + +Write Access Pattern +-------------------- + +For LDS write instructions like ``ds_write_b128``, the hardware provides conflict-free access when threads write to consecutive addresses. Each phase of 8 lanes writes to different banks, avoiding conflicts. + +Read Access Pattern +------------------- + +Similarly for LDS read instruction ``ds_read_b128``, when there is no bank conflict in these 8 lane groups: + +- 0:3+20:23 +- 4:7+16:19 +- 8:11+28:31 +- 12:15+24:27 +- 32:35+52:55 +- 36:39+48:51 +- 40:43+60:63 +- 44:47+56:59 + +then it's bank conflict-free for LDS reading. + +The data is accessed vertically because, in most cases, the LDS read is followed by an MFMA instruction in the next step, and MFMA requires the data to be accessed vertically as shown above. + +The LDS read access pattern illustrated below is typical for LDS usage in machine learning workloads. The read pattern can generate 4-way bank conflicts in every phase of access. You can experiment with ``row_padding`` (padding in a number of banks) to see if the problem can be solved this way, but also remember that in practice this will require additional LDS storage. The bigger the padding, the more additional storage is necessary. + +XOR Preshuffle: An Alternative to Padding +========================================= + +Another technique to reduce LDS bank conflicts is **XOR preshuffling** (see :ref:`ck_tile_lds_index_swapping` for detailed implementation). Instead of adding padding between rows, we can permute the column indices for each row using XOR. This method can help to avoid bank conflicts without allocating extra storage in LDS. 
+ +For a wavefront of 64 threads, if each thread writes a vector of 8 fp16 elements (16 bytes), and the row size is 64 elements, the column index for each element in a row is adjusted as follows: + +- ``KTypeSize = 2`` +- ``KPerBlock = 64`` // 64 elements per row +- ``KPack = 8`` // 8 elements per thread + +The adjusted column position for element ``(x, y)`` is: + +.. math:: + + x' = \left( y \bmod \frac{\text{KPerBlock}}{\text{KPack}} \right) \oplus x + +where :math:`\oplus` is the bitwise XOR, and :math:`x, y` are the original positions of a vector element with respect to the LDS banks. + +C++ Implementation +================== + +Here's how CK implements XOR preshuffling: + +.. code-block:: cpp + + // XOR-based column index adjustment + template + __device__ constexpr index_t xor_preshuffle(index_t row, index_t col) + { + constexpr index_t num_cols = KPerBlock / KPack; + return (row % num_cols) ^ col; + } + + // LDS write with XOR preshuffle + template + __device__ void lds_write_with_xor(DataType* lds_ptr, + const DataType* src, + index_t row, + index_t col) + { + // Apply XOR preshuffle to column index + index_t col_xor = xor_preshuffle<64, 8>(row, col); + + // Write to LDS with adjusted column + index_t offset = row * RowStride + col_xor * 8; + + // Vectorized write (assuming 128-bit write) + *reinterpret_cast(lds_ptr + offset) = + *reinterpret_cast(src); + } + + // LDS read with XOR preshuffle + template + __device__ void lds_read_with_xor(DataType* dst, + const DataType* lds_ptr, + index_t row, + index_t col) + { + // Apply same XOR preshuffle for read + index_t col_xor = xor_preshuffle<64, 8>(row, col); + + // Read from LDS with adjusted column + index_t offset = row * RowStride + col_xor * 8; + + // Vectorized read + *reinterpret_cast(dst) = + *reinterpret_cast(lds_ptr + offset); + } + +Integration with CK Tile +======================== + +CK Tile handles LDS bank conflict avoidance through its abstractions: + +1. 
**TileWindow** (:ref:`ck_tile_tile_window`): Automatically applies XOR preshuffling when loading/storing to LDS +2. **StaticDistributedTensor** (:ref:`ck_tile_static_distributed_tensor`): Manages LDS allocation with proper alignment +3. **LoadStoreTraits** (:ref:`ck_tile_load_store_traits`): Selects optimal access patterns to minimize conflicts + +Example usage in CK Tile: + +.. code-block:: cpp + + // CK Tile automatically handles bank conflict avoidance + template + __device__ void gemm_kernel() + { + // Create tile window with automatic XOR preshuffle + auto a_window = make_tile_window( + a_tensor_view, + tile_size, + origin, + tile_distribution); + + // Load to LDS - XOR preshuffle applied automatically + auto a_lds_tensor = make_static_distributed_tensor< + element_type, + decltype(tile_distribution)>(); + + a_window.load(a_lds_tensor); + + // Subsequent reads from LDS are conflict-free + // See :ref:`ck_tile_sweep_tile` for sweep operations + sweep_tile(a_lds_tensor, [](auto idx, auto& val) { + // Process data... + }); + } + +Performance Impact +================== + +Proper LDS bank conflict avoidance can have significant performance impact: + +- **4-way conflicts**: Can reduce effective LDS bandwidth by 75% +- **XOR preshuffle**: Restores full bandwidth with zero storage overhead +- **Padding**: Also effective but requires 12.5-25% more LDS storage + +Best Practices +============== + +1. **Use CK Tile abstractions**: They automatically handle bank conflict avoidance +2. **Prefer XOR preshuffle**: No storage overhead compared to padding +3. **Verify with profiling**: Use rocprof to check for LDS bank conflicts +4. **Consider access patterns**: Design algorithms with bank-friendly patterns + +By understanding LDS bank conflicts and using CK Tile's automatic conflict avoidance mechanisms, developers can achieve optimal shared memory performance without manual optimization. 
+ +Related Topics +============== + +- :ref:`ck_tile_lds_index_swapping` - Detailed XOR preshuffle implementation +- :ref:`ck_tile_swizzling_example` - Morton ordering for memory swizzling +- :ref:`ck_tile_gpu_basics` - Understanding AMD GPU architecture +- :ref:`ck_tile_tile_window` - Automatic conflict avoidance in data access +- :ref:`ck_tile_static_distributed_tensor` - LDS memory management +- :ref:`ck_tile_gemm_optimization` - Practical application in GEMM kernels +- :ref:`ck_tile_transforms` - Coordinate transformations for conflict avoidance diff --git a/docs/conceptual/ck_tile/index.rst b/docs/conceptual/ck_tile/index.rst new file mode 100644 index 0000000000..287143d6de --- /dev/null +++ b/docs/conceptual/ck_tile/index.rst @@ -0,0 +1,108 @@ +.. _ck_tile_conceptual: + +CK Tile Conceptual Documentation +================================ + +Welcome to the conceptual documentation for CK Tile, the core abstraction layer of Composable Kernel that enables efficient GPU programming through compile-time coordinate transformations and tile-based data distribution. + +See the :ref:`ck_tile_index` for the complete CK Tile documentation structure. + +Overview +-------- + +CK Tile provides a mathematical framework for expressing complex GPU computations through: + +- **Automatic Memory Coalescing**: Ensures optimal memory access patterns without manual optimization +- **Thread Cooperation**: Coordinates work distribution across the GPU's hierarchical execution model +- **Zero-Overhead Abstractions**: Compile-time optimizations ensure no runtime performance penalty +- **Portable Performance**: Same code achieves high performance across different GPU architectures + +Why CK Tile? 
+------------ + +Traditional GPU programming requires manual management of: + +- Thread-to-data mapping calculations +- Memory coalescing patterns +- Bank conflict avoidance +- Boundary condition handling + +CK Tile automates all of these concerns through a unified abstraction that maps logical problem coordinates to physical GPU resources. + + +Learning Path +------------- + +1. **Start Here**: :ref:`ck_tile_introduction` + + The fundamental problems CK Tile solves and why it's essential for efficient GPU programming. + +2. **Foundation**: :ref:`ck_tile_buffer_views` + + How CK Tile provides structured access to raw GPU memory across different address spaces. + +3. **Multi-Dimensional Views**: :ref:`ck_tile_tensor_views` + + How to work with multi-dimensional data structures and memory layouts. + +4. **Core API**: :ref:`ck_tile_distribution` + + The tile distribution system that maps work to GPU threads. + +5. **Mathematical Framework**: :ref:`ck_tile_coordinate_systems` + + The coordinate transformation system that powers CK Tile's abstractions. + +6. **Reference**: :ref:`ck_tile_terminology` + + Glossary of all terms and concepts used in CK Tile. + + +Key Concepts at a Glance +------------------------ + +**Coordinate Spaces** + +- **P-space**: Processing element coordinates (thread, warp, block) +- **Y-space**: Local tile access patterns +- **X-space**: Physical tensor coordinates +- **D-space**: Linearized memory addresses + +**Core Components** + +- **BufferView**: Type-safe access to GPU memory +- **TileDistribution**: Automatic work distribution +- **TileWindow**: Efficient data loading/storing +- **Encoding**: Compile-time distribution specification + +Quick Example +------------- + +.. 
code-block:: cpp + + // Define how to distribute a 256x256 tile across threads + using Encoding = tile_distribution_encoding< + sequence<>, // No replication + tuple, // M dimension hierarchy + sequence<4,2,8,4>>, // N dimension hierarchy + tuple, sequence<1,2>>, // Thread mapping + tuple, sequence<2,2>>, // Minor indices + sequence<1,1,2,2>, // Y-space mapping + sequence<0,3,0,3> // Y-space minor + >; + + // Create distribution and load data + auto distribution = make_static_tile_distribution(Encoding{}); + auto window = make_tile_window(tensor_view, tile_size, origin, distribution); + auto tile = window.load(); + + // Process tile efficiently + sweep_tile(tile, [](auto idx) { /* computation */ }); + + +Next Steps +---------- + +To dive deeper, start with :ref:`ck_tile_introduction` to understand the motivation and core concepts behind CK Tile. + +For practical examples, see the `example/ck_tile `_ directory in the Composable Kernel repository. diff --git a/docs/conceptual/ck_tile/introduction_motivation.rst b/docs/conceptual/ck_tile/introduction_motivation.rst new file mode 100644 index 0000000000..9884901556 --- /dev/null +++ b/docs/conceptual/ck_tile/introduction_motivation.rst @@ -0,0 +1,309 @@ +.. _ck_tile_introduction: + +Introduction and Motivation - Why Tile Distribution Matters +=========================================================== + +Overview +-------- + +The evolution of GPU computing has brought unprecedented computational power to modern applications, yet harnessing this power efficiently remains one of the most challenging aspects of high-performance computing. At the heart of this challenge lies a fundamental mismatch between how developers conceptualize algorithms and how GPU hardware executes them. While developers think in terms of mathematical operations on multi-dimensional data structures, GPUs operate through thousands of threads accessing memory in complex patterns that must satisfy stringent hardware constraints. 
+ +This conceptual gap manifests most acutely in memory access patterns. Modern GPUs achieve their high performance through massive parallelism, with thousands of threads executing simultaneously. However, this parallelism comes with a critical constraint: memory bandwidth. Despite continuous improvements in computational throughput, memory bandwidth has not scaled proportionally, creating what is often called the "memory wall." The efficiency with which threads access memory determines whether a GPU kernel achieves a few percent or near 100% of the hardware's theoretical performance. + +The Composable Kernel (CK) framework addresses this challenge through its tile distribution system, a compile-time abstraction that automatically generates optimal memory access patterns while preserving the natural expression of algorithms. This documentation explores the mathematical foundations and practical implementation of tile distribution, demonstrating how it bridges the gap between algorithmic intent and hardware reality. + +In this introduction, we establish the fundamental problems that tile distribution solves, explore why these problems are critical for GPU performance, and provide the conceptual framework necessary to understand the compile-time coordinate transformation system that powers CK's approach to efficient GPU computation. + +The GPU Memory Problem +---------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. 
mermaid:: + + graph TB + subgraph "Random Access Pattern (Inefficient)" + subgraph "Threads" + T0_R["Thread 0"] + T1_R["Thread 1"] + T2_R["Thread 2"] + T3_R["Thread 3"] + end + + subgraph "Memory" + M0["Mem[0]"] + M7["Mem[7]"] + M15["Mem[15]"] + M23["Mem[23]"] + M31["Mem[31]"] + M39["Mem[39]"] + M47["Mem[47]"] + M55["Mem[55]"] + end + + T0_R -.-> M23 + T1_R -.-> M7 + T2_R -.-> M47 + T3_R -.-> M15 + end + + subgraph "Tile Distribution Pattern (Efficient)" + subgraph "Threads_TD" + T0_TD["Thread 0"] + T1_TD["Thread 1"] + T2_TD["Thread 2"] + T3_TD["Thread 3"] + end + + subgraph "Memory_TD" + M0_TD["Mem[0]"] + M1_TD["Mem[1]"] + M2_TD["Mem[2]"] + M3_TD["Mem[3]"] + M4_TD["Mem[4]"] + M5_TD["Mem[5]"] + M6_TD["Mem[6]"] + M7_TD["Mem[7]"] + end + + T0_TD --> M0_TD + T0_TD --> M1_TD + T1_TD --> M2_TD + T1_TD --> M3_TD + T2_TD --> M4_TD + T2_TD --> M5_TD + T3_TD --> M6_TD + T3_TD --> M7_TD + end + + style T0_R fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style T1_R fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style T2_R fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style T3_R fill:#fee2e2,stroke:#ef4444,stroke-width:2px + + style T0_TD fill:#d1fae5,stroke:#10b981,stroke-width:2px + style T1_TD fill:#d1fae5,stroke:#10b981,stroke-width:2px + style T2_TD fill:#d1fae5,stroke:#10b981,stroke-width:2px + style T3_TD fill:#d1fae5,stroke:#10b981,stroke-width:2px + + + +.. image:: diagrams/introduction_motivation_1.svg + :alt: Diagram + :align: center + +Why Random Memory Access is Slow +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The architecture of modern GPUs represents a study in trade-offs. While these devices can execute thousands of threads simultaneously and perform trillions of floating-point operations per second, they remain fundamentally constrained by the physics of memory access. Understanding this constraint is crucial to appreciating why tile distribution is not merely an optimization technique but an essential component of high-performance GPU computing. 
+ +GPU memory systems are designed around the assumption of regular, predictable access patterns. The memory controller can service requests from the 64 threads of a wavefront (the AMD term for a warp on CDNA GPUs) in a single transaction when these threads access consecutive memory locations. This optimization, known as memory coalescing, can improve effective memory bandwidth by up to 64x compared to random access patterns. However, when threads within a warp access memory locations that are scattered throughout the address space, each access requires a separate memory transaction, reducing the effective bandwidth to a fraction of the theoretical maximum. + +The impact extends beyond raw bandwidth. Modern GPUs employ cache hierarchies to reduce memory latency, but these caches are effective only when access patterns exhibit spatial or temporal locality. Random access patterns defeat these optimizations, causing frequent cache misses that expose the full latency of global memory access, which can be hundreds of cycles. During these stalls, the computational units sit idle, unable to hide the latency even with the GPU's massive thread count. + +Furthermore, the GPU's Single Instruction, Multiple Thread (SIMT) execution model requires that all threads in a warp execute the same instruction at the same time. When threads access memory in unpredictable patterns, the memory controller cannot optimize the requests, leading to serialization of what should be parallel operations. This serialization effect compounds with each level of the memory hierarchy, from L1 cache through L2 cache to global memory, multiplying the performance impact. + +The Thread Cooperation Challenge +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The challenge of efficient thread cooperation becomes particularly evident when examining a fundamental operation like matrix multiplication. Consider a scenario where 256 threads must cooperate to multiply two matrices. 
The naive approach, where each thread computes one element of the output matrix, illustrates precisely why GPU programming requires compile-time abstractions. + +.. code-block:: cpp + + // Inefficient: Random access pattern + __device__ void naive_matrix_multiply() + { + int thread_id = threadIdx.x + blockIdx.x * blockDim.x; + + // Get this thread's output position + int row = thread_id / MATRIX_WIDTH; + int col = thread_id % MATRIX_WIDTH; + + // Each thread computes one element of C = A * B + float result = 0.0f; + for (int k = 0; k < MATRIX_WIDTH; k++) + { + // Random access pattern - threads in a warp access non-contiguous memory + // Thread 0: A[0,0], A[0,1], A[0,2]... + // Thread 1: A[1,0], A[1,1], A[1,2]... + // These are far apart in memory! + float a_element = global_memory_A[row * MATRIX_WIDTH + k]; + + // Even worse for B - accessing column-wise causes strided access + // Thread 0: B[0,0], B[1,0], B[2,0]... + // Thread 1: B[0,1], B[1,1], B[2,1]... + // Massive stride between accesses! + float b_element = global_memory_B[k * MATRIX_WIDTH + col]; + + result += a_element * b_element; + } + + // Write result - adjacent threads write to adjacent locations (at least this is good) + global_memory_C[row * MATRIX_WIDTH + col] = result; + } + +This seemingly straightforward implementation suffers from fundamental inefficiencies that stem from the mismatch between the algorithm's logical structure and the hardware's physical constraints. The memory access pattern is essentially random from the hardware's perspective, as adjacent threads access memory locations separated by large strides. This pattern prevents the memory controller from coalescing accesses, forcing it to issue separate transactions for each thread. + +The lack of coordination between threads exacerbates the problem. While all threads in a warp execute the same instruction, they operate on completely different data with no sharing or reuse. 
This independence, which might seem desirable in traditional parallel programming, actually works against GPU architecture. The hardware cannot exploit any commonality in the access patterns, leading to severe underutilization of memory bandwidth. + +Cache utilization suffers dramatically under this access pattern. Each thread traces a unique path through memory, with no overlap between threads' working sets. The L1 and L2 caches, designed to capture and exploit locality, instead thrash continuously as each thread's accesses evict data needed by others. The effective cache capacity approaches zero, exposing every memory access to the full latency of global memory. + +Perhaps most critically, this approach fails to utilize the available memory bandwidth efficiently. Modern GPUs can achieve memory bandwidths exceeding 1 TB/s, but only when accesses are properly structured. The random access pattern of the naive implementation might achieve less than 10% of this theoretical maximum, effectively reducing a high-performance GPU to the performance level of a much simpler processor. + +The Tile Distribution Solution +------------------------------ + +Structured Mapping from Logical to Physical Coordinates +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The fundamental innovation of tile distribution lies in its approach to the memory access problem. Rather than attempting to optimize the naive access patterns after the fact, tile distribution provides a mathematical framework that generates optimized patterns from the outset. This framework establishes a structured mapping between logical coordinates and physical coordinates that respect hardware constraints. + +The essence of tile distribution is the recognition that efficient GPU computation requires a careful choreography of thread cooperation. Instead of each thread operating independently, threads are organized into hierarchical groups that work together on tiles of data. 
This organization ensures that when threads access memory, they do so in patterns that the hardware can optimize. + +.. code-block:: cpp + + // Efficient: Tile-based distribution using CK Tile + template + __device__ void tile_distributed_matrix_multiply() + { + // 1. Define tile distribution encoding at compile time + using Encoding = tile_distribution_encoding< + sequence<>, // No replication + tuple, // M dimension hierarchy + sequence<4, 2, 8, 4>>, // N dimension hierarchy + tuple, sequence<1, 2>>, // P to RH major + tuple, sequence<2, 2>>, // P to RH minor + sequence<1, 1, 2, 2>, // Y to RH major + sequence<0, 3, 0, 3> // Y to RH minor + >; + + // 2. Create the distribution + constexpr auto distribution = make_static_tile_distribution(Encoding{}); + + // 3. Create tile window for efficient memory access + auto tile_window = make_tile_window( + tensor_view, + window_lengths, + origin, + distribution + ); + + // 4. Load data with coalesced access pattern + auto loaded_tensor = tile_window.load(); + + // 5. Process tile data efficiently + sweep_tile(loaded_tensor, [](auto y_indices) { + auto value = loaded_tensor(y_indices); + // ... efficient computation + }); + } + +The transformation from inefficient to efficient memory access is profound. Where the naive implementation scattered memory requests across the address space, tile distribution ensures that adjacent threads access adjacent memory locations. This transformation happens through an advanced encoding system that captures the hierarchical nature of both the computation and the hardware. + +The encoding shown above demonstrates the multi-level hierarchy that tile distribution employs. The sequence<4, 2, 8, 4> represents a four-level decomposition: four repetitions per thread, two warps per block, eight threads per warp, and four elements per vector operation. 
This hierarchical structure maps directly to the GPU's hardware organization, ensuring that each level of the hierarchy operates at maximum efficiency. + +Memory access patterns become predictable and regular under tile distribution. The hardware's memory coalescing logic can now combine the requests from all threads in a warp into a single transaction, achieving the full memory bandwidth. The predictability extends beyond individual accesses to entire access sequences, enabling the hardware's prefetching mechanisms to anticipate and prepare data before it's needed. + +Thread cooperation emerges naturally from the tile distribution structure. Threads within a warp work on adjacent data, enabling efficient data sharing through register shuffle operations. Warps within a block coordinate through shared memory, with access patterns that avoid bank conflicts. This cooperation transforms what was a collection of independent computations into a unified, efficient operation. + +Cache utilization improves as well. The structured access patterns ensure that data loaded into cache by one thread is likely to be used by neighboring threads. Temporal locality emerges from the tile-based processing, where all operations on a tile complete before moving to the next tile. This locality transforms the cache from a liability into a high-performance accelerator. + +The scalability of tile distribution across different GPU architectures represents one of its most important features. The same high-level code can achieve near-optimal performance on GPUs with different numbers of compute units, different cache sizes, and different memory bandwidths. The compile-time nature of the encoding allows the compiler to generate architecture-specific optimizations while maintaining portable source code. 
+ +The Coordinate Mapping Insight +------------------------------ + +At the heart of tile distribution lies a profound mathematical insight: efficient GPU computation requires a systematic framework for mapping between different coordinate spaces. This framework transforms the complex problem of thread-to-data assignment into a series of well-defined mathematical transformations, each serving a specific purpose in the journey from abstract algorithm to concrete hardware execution. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Coordinate Spaces" + P["P-space
Thread Position
(thread_x, thread_y,
warp_id, block_id)"] + Y["Y-space
Local Data
(y0, y1, y2, y3)"] + X["X-space
Global Position
(x0, x1)"] + D["D-space
Memory Address
(linearized)"] + end + + subgraph "Transformations" + T1["P + Y → X
Thread data mapping"] + T2["X → D
Memory linearization"] + end + + P --> T1 + Y --> T1 + T1 --> X + X --> T2 + T2 --> D + + style P fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style Y fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style X fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style D fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + style T1 fill:#fef3c7,stroke:#f59e0b,stroke-width:2px + style T2 fill:#fef3c7,stroke:#f59e0b,stroke-width:2px + + + +.. image:: diagrams/introduction_motivation_2.svg + :alt: Diagram + :align: center + +The elegance of this approach emerges from its separation of concerns. Each coordinate space represents a distinct aspect of the computation, and the transformations between them encapsulate specific optimization strategies. This separation allows developers to reason about their algorithms in natural terms while the framework handles the complex mapping to efficient hardware execution patterns. + +**Thread Position Space (P-space)** represents the physical organization of threads on the GPU. This space captures the hierarchical nature of GPU execution, from individual threads identified by their x and y coordinates within a block, to warps that execute in lockstep, to thread blocks that share resources. The coordinates in P-space—thread_x, thread_y, warp_id, and block_id—directly correspond to the hardware's execution model. Understanding P-space is crucial because it determines which threads can cooperate efficiently through shared memory and which threads will execute their memory accesses simultaneously. + +**Local Data Space (Y-space)** embodies the algorithm's perspective on data organization. In this space, each thread reasons about its local portion of work using coordinates like y0, y1, y2, and y3. These coordinates are algorithm-specific and represent the natural way to index the data being processed. For matrix multiplication, Y-space might represent the local tile coordinates within a larger matrix. 
For convolution, it might represent the spatial dimensions and channels of a local receptive field. The beauty of Y-space is that it allows algorithms to be expressed in their most natural form, without concern for hardware-specific optimizations. + +**Global Position Space (X-space)** serves as the bridge between algorithmic intent and physical reality. This space represents the actual global coordinates of data in the problem domain, such as the row and column indices in a matrix or the spatial coordinates in an image. X-space is where the distributed nature of the computation becomes explicit, as each thread's local Y-space coordinates combine with its position in P-space to determine which global data elements it accesses. + +**Memory Address Space (D-space)** represents the final destination: linearized memory addresses that the hardware actually uses. This space accounts for the fact that multi-dimensional data structures must ultimately be stored in linear memory. The transformation to D-space incorporates layout optimizations such as padding for alignment, interleaving for better cache utilization, and address space considerations for different memory types (global, shared, or constant memory). + +The transformative power of tile distribution emerges from the composition of these mappings. The **P + Y → X** transformation combines a thread's position with its local data coordinates to determine global data positions. This transformation encodes the distribution strategy, determining how work is partitioned across threads. The subsequent **X → D** transformation converts these logical positions into physical memory addresses, incorporating layout optimizations that ensure efficient memory access patterns. + +The mathematical rigor of this framework enables critical optimizations. Because each transformation is well-defined and composable, the compiler can analyze the complete transformation chain and generate optimal code. 
The framework can automatically ensure memory coalescing by structuring the P + Y → X transformation appropriately. It can minimize bank conflicts in shared memory by carefully designing the X → D mapping. Most importantly, it can adapt these optimizations to different hardware architectures by adjusting the transformation parameters while keeping the high-level algorithm description unchanged. + +What's Coming Next +------------------ + +Having established the fundamental motivation for tile distribution and its coordinate mapping framework, this documentation now embarks on a systematic journey through the complete CK Tile system. This journey is carefully structured to build understanding layer by layer, starting from the most basic abstractions and progressing to advanced optimization techniques. + +The foundation of the exploration begins with raw memory access through :ref:`ck_tile_buffer_views`, the fundamental abstraction that provides type-safe, address-space-aware access to GPU memory. Understanding BufferView is crucial because it establishes the patterns and principles that permeate the entire CK Tile system. From there, the exploration progresses to :ref:`ck_tile_tensor_views`, which adds multi-dimensional structure to raw memory, enabling natural expression of algorithms while maintaining the efficiency of the underlying buffer operations. + +With these foundational concepts established, the documentation delves into the coordinate transformation engine (:ref:`ck_tile_coordinate_systems`) that powers tile distribution. This engine implements the mathematical framework that has been introduced, providing compile-time transformations between P-space, Y-space, X-space, and D-space. Understanding these transformations at a deep level enables developers to reason about performance implications and design custom distribution strategies for novel algorithms. The :ref:`ck_tile_transforms` and :ref:`ck_tile_adaptors` provide the building blocks for these transformations. 
+ +The high-level :ref:`ck_tile_distribution` APIs represent the culmination of these lower-level abstractions. These APIs provide an accessible interface for common patterns while exposing enough flexibility for advanced optimizations. Through concrete examples and detailed explanations, the documentation will demonstrate how to leverage these APIs to achieve near-optimal performance across a variety of computational patterns. The :ref:`ck_tile_window` abstraction provides the gateway for efficient data access. + +The exploration of coordinate systems goes beyond the basic P, Y, X, D framework to encompass advanced topics such as multi-level tiling, replication strategies, and specialized coordinate systems for specific algorithm classes. The :ref:`ck_tile_encoding_internals` reveals the mathematical foundations, while :ref:`ck_tile_thread_mapping` shows how these abstractions map to hardware. This comprehensive treatment ensures that developers can handle not just common cases but also novel algorithms that require custom distribution strategies. + +The implementation details reveal the template metaprogramming techniques that enable CK Tile's zero-overhead abstractions. Topics like :ref:`ck_tile_descriptors`, :ref:`ck_tile_load_store_traits`, and :ref:`ck_tile_static_distributed_tensor` show how these abstractions achieve zero overhead. By understanding these implementation strategies, advanced developers can extend the framework, contribute optimizations, and debug performance issues at the deepest level. + +The connection between abstract coordinate transformations and concrete hardware thread mapping represents a critical piece of the puzzle. The documentation will examine how logical thread organizations map to physical GPU resources, how to avoid common pitfalls like bank conflicts (see :ref:`ck_tile_lds_bank_conflicts` and :ref:`ck_tile_lds_index_swapping`) and divergent execution, and how to structure computations for maximum hardware utilization. 
The :ref:`ck_tile_hardware` section provides deep dives into architecture-specific optimizations. + +Finally, the advanced topics section explores cutting-edge optimization techniques, including :ref:`ck_tile_space_filling_curve` for optimal memory traversal, :ref:`ck_tile_sweep_tile` for clean iteration patterns, and practical examples like :ref:`ck_tile_convolution_example` and :ref:`ck_tile_gemm_optimization`. These topics prepare developers to push the boundaries of GPU performance and contribute to the ongoing evolution of high-performance computing. + +Summary +------- + +The journey through this introduction has revealed tile distribution as a fundamental paradigm shift in how GPU programming is approached. By establishing a mathematical framework for coordinate transformation, tile distribution bridges the gap between algorithmic elegance and hardware efficiency. + +The significance of this approach extends beyond mere performance optimization. Tile distribution enables developers to express algorithms in their natural mathematical form while achieving performance that approaches the theoretical limits of the hardware. This reconciliation of abstraction and efficiency has been a goal of high-performance computing, and tile distribution provides a step towards this goal. + +The structured, predictable mappings between logical and physical coordinates that tile distribution provides yield multiple benefits. Efficient memory access emerges naturally from the framework, with coalesced access patterns and cache-friendly layouts arising from the mathematical structure rather than manual optimization. Thread cooperation becomes an inherent property of the system, with the distribution encoding automatically organizing threads into efficient collaborative patterns. 
The scalability across different hardware architectures demonstrates the power of abstraction—the same high-level code achieves near-optimal performance whether running on a small mobile GPU or a massive datacenter accelerator. + +Perhaps most importantly, tile distribution provides a predictable optimization framework grounded in mathematical principles. Performance characteristics can be analyzed and predicted based on the transformation structure, enabling systematic optimization rather than trial-and-error tuning. This predictability transforms GPU optimization from an art practiced by a few experts into a science accessible to a broader community of developers. + +The systematic mapping through P-space, Y-space, X-space, and D-space provides a mental model that clarifies the entire GPU computation process. This model enables developers to reason about their code at multiple levels of abstraction simultaneously, understanding both the high-level algorithmic behavior and the low-level hardware execution patterns. + +As the documentation dives deeper into the implementation details, starting with the foundational BufferView abstraction, it is important to remember that each component serves the larger purpose of enabling efficient, scalable GPU computation. The journey from raw memory to advanced tile distributions mirrors the evolution of GPU programming itself, from low-level, hardware-specific optimizations to high-level, portable abstractions that preserve efficiency. + +By providing a framework for achieving optimal memory access patterns, tile distribution enables developers to take advantage of the computing power of GPUs without having to know the details of the underlying architecture. + +Next Steps +---------- + +Continue to :ref:`ck_tile_buffer_views` to start building your understanding from the ground up. 
diff --git a/docs/conceptual/ck_tile/lds_index_swapping.rst b/docs/conceptual/ck_tile/lds_index_swapping.rst new file mode 100644 index 0000000000..891b32f9ed --- /dev/null +++ b/docs/conceptual/ck_tile/lds_index_swapping.rst @@ -0,0 +1,462 @@ +.. meta:: + :description: CK Tile LDS index swapping documentation + :keywords: CK Tile, LDS, index swapping, XOR preshuffle, bank conflicts, GPU optimization + +.. _ck_tile_lds_index_swapping: + +******************************** +Local Data Share Index Swapping +******************************** + +Overview +======== + +Local Data Share (LDS) index swapping, also known as XOR preshuffle, is a critical optimization technique in CK Tile for resolving bank conflicts in shared memory. Bank conflicts occur when multiple threads in a warp attempt to access different addresses within the same memory bank simultaneously, causing serialization and performance degradation. CK Tile generalizes the XOR preshuffle technique through a compile-time coordinate transformation system that automatically handles complex access patterns. + +The key insight is that transforming the logical 2D coordinates used to access LDS into a different 2D coordinate space ensures that threads accessing data simultaneously access different memory banks. This transformation is implemented through CK Tile's composable transform system, making it both flexible and efficient. See :ref:`ck_tile_transforms` and :ref:`ck_tile_coordinate_systems` for more information about the composable transform system. + +Coordinate Transformation Pipeline +================================== + +CK Tile performs coordinate transformations to bring LDS access from the original 2D position (M, K dimensions) into transformed (M', K') coordinates: + +Step 1: XOR Transform +--------------------- + +The original K coordinate is split into K0 and K1, where K1 represents the thread vector size along the K dimension (KPack) and K0 is KPerBlock/KPack. + +.. 
+ Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "3D LDS coordinate [K0, M, K1]" + K0["KPerBlock/KPack * MLdsLayer
K0"] + M["MPerBlock/MLdsLayer
M"] + K1["KPack
K1"] + end + + subgraph "XOR Transform" + XT["make_xor_transform"] + end + + subgraph "Update K0 with XOR transformation" + K01["KPerBlock/KPack * MLdsLayer
K0'"] + M1["MPerBlock/MLdsLayer
M"] + K11["KPack
K1"] + end + + K0 --> XT + M --> XT + K1 --> K11 + + XT --> K01 + XT --> M1 + + style K0 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style K01 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style M fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style M1 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + + style K1 fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style K11 fill:#fff3e0,stroke:#f57c00,stroke-width:2px + + + +.. image:: diagrams/lds_index_swapping_1.svg + :alt: Diagram + :align: center + +The XOR transformation updates the K0 coordinate using the formula: + +.. code-block:: cpp + + K0' = K0 ^ (M % (KPerBlock / KPack * MLdsLayer)) + +This XOR operation redistributes accesses across memory banks by mixing bits from the M and K dimensions. + +Step 2: Unmerge Transform +------------------------- + +The transformed K0' is split into L and K0'' components, creating an intermediate 4D coordinate space. This is necessary when MLdsLayer > 1, allowing multiple rows to share the same set of memory banks for better utilization with smaller tile sizes. + + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "3D LDS coordinate [K0', M, K1]" + K0["KPerBlock/KPack * MLdsLayer
K0'"] + M["MPerBlock/MLdsLayer
M"] + K1["KPack
K1"] + end + + subgraph "Unmerge into 2 components" + UM["make_unmerge_transform"] + end + + subgraph "4D intermediate transformation space" + L["MLdsLayer
L"] + M1["MPerBlock/MLdsLayer
M"] + K01["KPerBlock/KPack
K0''"] + K11["KPack
K1"] + end + + K0 --> UM + M --> M1 + K1 --> K11 + + UM --> L + UM --> K01 + + style K0 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style L fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style K01 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + + style M fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style M1 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + style K1 fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style K11 fill:#fff3e0,stroke:#f57c00,stroke-width:2px + + + + + +.. image:: diagrams/lds_index_swapping_2.svg + :alt: Diagram + :align: center + +The unmerge operation: + +.. code-block:: cpp + + L = K0' / (KPerBlock/KPack) + K0'' = K0' % (KPerBlock/KPack) + +When MLdsLayer == 1, this simplifies to L=0 and K0''=K0'. + +Step 3: Merge Transform +----------------------- + +The final step merges the 4D coordinates back into 2D transformed coordinates (M', K'). + + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "4D LDS Coordinates [L, M, K0'', K1]" + L["MLdsLayer
L"] + M1["MPerBlock/MLdsLayer
M"] + K0["KPerBlock/KPack
K0''"] + K1["KPack
K1"] + end + + subgraph "Merge into 1 component" + ME0["make_merge_transform"] + end + + subgraph "Merge into 1 component" + ME1["make_merge_transform"] + end + + subgraph "Transformed 2D coordinates [M', K']" + M11["MPerBlock
M'"] + K01["KPerBlock
K'"] + end + + L --> ME0 + M1 --> ME0 + + K0 --> ME1 + K1 --> ME1 + + ME0 --> M11 + ME1 --> K01 + + style K0 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style K1 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style K01 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + + style M1 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style L fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style M11 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + + +.. image:: diagrams/lds_index_swapping_3.svg + :alt: Diagram + :align: center + + +C++ Implementation +================== + +Here's how the complete transformation chain is implemented in CK Tile using :ref:`ck_tile_descriptors` and transforms: + +.. code-block:: cpp + + template + struct LdsIndexSwapping { + static constexpr index_t KPerBlock_over_KPack = KPerBlock / KPack; + static constexpr index_t MPerBlock_over_MLdsLayer = MPerBlock / MLdsLayer; + + // Step 1: Create base descriptor + using BaseLengths = Sequence< + KPerBlock_over_KPack * MLdsLayer, + MPerBlock_over_MLdsLayer, + KPack + >; + using BaseStrides = Sequence< + KPack, + KPerBlock * MLdsLayer, + 1 + >; + + using BaseDescriptor = TensorDescriptor; + + // Step 2: Apply XOR transform + using PermutedDescriptor = decltype( + transform_tensor_descriptor( + BaseDescriptor{}, + make_tuple( + make_xor_transform( + Sequence{} + ), + make_pass_through_transform(Number{}) + ), + Sequence<1, 0>{}, // XOR on dims [1,0] + Sequence<2>{} // Pass through dim 2 + ) + ); + + // Step 3: Apply unmerge and final transforms + using FinalDescriptor = decltype( + transform_tensor_descriptor( + PermutedDescriptor{}, + make_tuple( + make_unmerge_transform( + Sequence{} + ), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}) + ), + Sequence<0>{}, // Unmerge dim 0 + Sequence<1>{}, // Pass through dim 1 + Sequence<2>{}, // Pass through dim 2 + Sequence<0, 2>{}, // Output dims from unmerge + Sequence<1>{}, // Output dim 1 + Sequence<3>{} // Output dim 3 + ) + ); 
+ }; + + + + +Practical Usage in GEMM +========================== + +Here's how LDS index swapping is used in a real GEMM kernel. See :ref:`ck_tile_gemm_optimization` for more information about GEMM optimization. + +.. code-block:: cpp + + template + __global__ void gemm_kernel_with_lds_swapping( + const DataType* __restrict__ a_global, + const DataType* __restrict__ b_global, + DataType* __restrict__ c_global, + index_t M, index_t N, index_t K) + { + // Shared memory allocation + __shared__ DataType a_lds[BlockM * BlockK]; + __shared__ DataType b_lds[BlockK * BlockN]; + + // Create LDS descriptor with index swapping + constexpr index_t MLdsLayer = 2; // Typical value for bank conflict avoidance + + using ALdsDesc = typename LdsIndexSwapping< + BlockK, KPack, MLdsLayer, BlockM + >::FinalDescriptor; + + // Load from global to LDS with swapped indices + auto load_a_to_lds = [&](index_t k_offset) { + // Each thread loads its portion + index_t tid = threadIdx.x; + constexpr index_t NumThreads = blockDim.x; + constexpr index_t ElementsPerThread = (BlockM * BlockK) / NumThreads; + + #pragma unroll + for (index_t i = 0; i < ElementsPerThread; ++i) { + index_t linear_idx = tid * ElementsPerThread + i; + + // Convert linear index to 2D coordinates + index_t m_idx = linear_idx / BlockK; + index_t k_idx = linear_idx % BlockK; + + // Load from global memory + DataType value = a_global[ + (blockIdx.y * BlockM + m_idx) * K + k_offset + k_idx + ]; + + // Store to LDS using swapped coordinates + ALdsDesc desc; + index_t lds_offset = desc.calculate_offset({ + 0, // L component (for this example) + m_idx / MLdsLayer, // M component + k_idx / KPack, // K0 component + k_idx % KPack // K1 component + }); + + a_lds[lds_offset] = value; + } + }; + + // Main GEMM computation loop + for (index_t k = 0; k < K; k += BlockK) { + // Load tiles to LDS with index swapping + load_a_to_lds(k); + __syncthreads(); + + // Compute using swapped LDS layout + // ... 
(matrix multiplication using transformed coordinates) + } + } + +Bank Conflict Analysis +====================== + +The effectiveness of index swapping can be analyzed by examining access patterns: + +.. code-block:: cpp + + template + struct BankConflictAnalyzer { + static constexpr index_t NumBanks = 32; + static constexpr index_t BankWidth = 4; // 4 bytes per bank + + template + static void analyze_access_pattern() { + // Simulate warp access pattern + index_t bank_access[NumBanks] = {0}; + + // Each thread in warp accesses one element + for (index_t tid = 0; tid < WarpSize; ++tid) { + // Calculate coordinates for this thread + index_t m_coord = tid / 8; // Example mapping + index_t k_coord = tid % 8; + + // Get LDS offset using descriptor + LdsDescriptor desc; + index_t offset = desc.calculate_offset({m_coord, k_coord}); + + // Determine bank + index_t bank = (offset * sizeof(float) / BankWidth) % NumBanks; + bank_access[bank]++; + } + + // Check for conflicts + index_t max_conflict = 0; + for (index_t bank = 0; bank < NumBanks; ++bank) { + max_conflict = max(max_conflict, bank_access[bank]); + } + + printf("Max bank conflict: %d-way\n", max_conflict); + } + }; + +Performance Benefits +==================== + +LDS index swapping provides several key benefits: + +1. **Conflict-Free Access**: Eliminates or significantly reduces bank conflicts +2. **Higher Throughput**: Enables full memory bandwidth utilization +3. **Automatic Optimization**: Transformation parameters can be tuned per architecture +4. **Composability**: Integrates seamlessly with other CK Tile transformations + +Advanced Configurations +======================= + +Different configurations can be used based on tile sizes and data types: + +.. code-block:: cpp + + // Configuration for different scenarios + template + struct LdsSwappingConfig { + // Smaller tiles may need different MLdsLayer + static constexpr index_t MLdsLayer = + (TileSize <= 32) ? 1 : + (TileSize <= 64) ? 
2 : 4; + + // Adjust KPack based on data type + static constexpr index_t KPack = + sizeof(DataType) == 2 ? 8 : // FP16/BF16 + sizeof(DataType) == 4 ? 4 : 2; // FP32 + + // Validate configuration + static_assert(TileSize % (MLdsLayer * KPack) == 0, + "Tile size must be divisible by MLdsLayer * KPack"); + }; + + +Integration with Tile Distribution +================================== + +LDS index swapping works seamlessly with CK Tile's distribution system. See :ref:`ck_tile_tile_distribution` for more information about CK Tile's distribution system. + +.. code-block:: cpp + + template + struct DistributedLdsAccess { + using LdsDesc = typename LdsIndexSwapping<...>::FinalDescriptor; + + __device__ void load_from_lds( + const float* lds_ptr, + StaticDistributedTensor& reg_tensor) + { + // Each thread loads its distributed portion + auto coord = make_tensor_coordinate(LdsDesc{}, {0, 0, 0, 0}); + + #pragma unroll + for (index_t i = 0; i < reg_tensor.size(); ++i) { + // Calculate swapped LDS coordinates for this element + auto [m, k] = TileDistribution::get_local_tile_index(i); + + // Move to correct position + move_tensor_coordinate(LdsDesc{}, coord, {0, m, k/4, k%4}); + + // Load with transformed coordinates + reg_tensor[i] = lds_ptr[coord.get_offset()]; + } + } + }; + +Summary +======= + +LDS index swapping in CK Tile provides an effective and flexible solution to the bank conflict problem: + +- **Generalized XOR Preshuffle**: Extends the basic XOR technique through composable transforms +- **Multi-Step Pipeline**: Coordinates flow through XOR → Unmerge → Merge transformations +- **Automatic Optimization**: Parameters like MLdsLayer adapt to tile sizes and data types +- **Zero Overhead**: All transformations resolve at compile time +- **Seamless Integration**: Works naturally with other CK Tile components + +By understanding and utilizing LDS index swapping, kernels can achieve maximum shared memory bandwidth, which is often the limiting factor in GPU kernel 
performance. The transformation-based approach makes it easy to experiment with different swapping strategies while maintaining code clarity. + +For practical examples of how index swapping is used in complete kernels, see :ref:`ck_tile_swizzling_example`. For more on coordinate operations used here, see :ref:`ck_tile_coordinate_movement` and :ref:`ck_tile_tensor_coordinates`. diff --git a/docs/conceptual/ck_tile/load_store_traits.rst b/docs/conceptual/ck_tile/load_store_traits.rst new file mode 100644 index 0000000000..f9555a8bfe --- /dev/null +++ b/docs/conceptual/ck_tile/load_store_traits.rst @@ -0,0 +1,480 @@ +.. _ck_tile_load_store_traits: + +LoadStoreTraits - Memory Access Optimization Engine +=================================================== + +Overview +-------- + +LoadStoreTraits is a critical optimization component that analyzes :ref:`tile distributions <ck_tile_tile_distribution>` to determine the most efficient memory access patterns. It serves as the engine behind :ref:`TileWindow's <ck_tile_tile_window>` high-performance data movement, automatically identifying the best dimension for vectorization and creating optimized access sequences using :ref:`space-filling curves <ck_tile_space_filling_curve>`. + +LoadStoreTraits performs compile-time analysis of the distribution pattern to extract key information about memory access opportunities. This analysis determines how many elements can be loaded or stored in a single instruction, which dimension provides the best vectorization opportunity, and what traversal order maximizes cache utilization. The result is a set of compile-time constants and methods that guide the runtime execution of load and store operations. 
+ +Key Concepts +------------ + +Vectorization Selection +~~~~~~~~~~~~~~~~~~~~~~~ + +LoadStoreTraits analyzes tensor dimensions to find the optimal one for vectorized loads and stores, prioritizing: + +- **Contiguous memory access** (stride = 1) +- **Maximum vector length** based on data type and :ref:`hardware capabilities <ck_tile_gpu_basics>` +- **Alignment requirements** for efficient memory transactions + +Space-Filling Curve Integration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The system automatically creates a :ref:`space-filling curve <ck_tile_space_filling_curve>` that maximizes cache utilization while respecting vectorization constraints. This ensures that consecutive memory accesses are spatially close, reducing cache misses and improving memory bandwidth utilization. + +Access Pattern Optimization +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +LoadStoreTraits manages the trade-off between vector size and number of memory accesses, finding a solution that minimizes total memory transactions while maximizing bandwidth utilization. + +C++ Implementation +------------------ + +The LoadStoreTraits class analyzes distribution patterns at compile time: + +.. 
code-block:: cpp + + template + struct load_store_traits + { + // Compile-time analysis results + static constexpr index_t ndim_y = Distribution::ndim_y; + static constexpr index_t ndim_x = Distribution::ndim_x; + + // Find which Y dimension has stride 1 (best for vectorization) + static constexpr index_t vector_dim_y = []() { + // Complex compile-time analysis to find optimal dimension + const auto strides = Distribution::calculate_y_strides(); + for (index_t i = 0; i < ndim_y; ++i) { + if (strides[i] == 1) return i; + } + return ndim_y - 1; // Default to last dimension + }(); + + // Calculate how many scalars fit in a vector + static constexpr index_t scalar_per_vector = []() { + // Determine based on data type and hardware capabilities + if constexpr (sizeof(DataType) == 4) { // float32 + return min(Distribution::get_y_length(vector_dim_y), 4); + } else if constexpr (sizeof(DataType) == 2) { // float16 + return min(Distribution::get_y_length(vector_dim_y), 8); + } + return 1; + }(); + + // Total scalars accessed per memory operation + static constexpr index_t scalars_per_access = scalar_per_vector; + + // Space-filling curve for optimal traversal + // See :ref:`ck_tile_space_filling_curve` for details + using sfc_type = space_filling_curve; + static constexpr sfc_type sfc_ys = make_space_filling_curve(); + + // Total number of accesses needed + static constexpr index_t num_access = + Distribution::get_num_of_element_y() / scalars_per_access; + + // Get Y indices for a given access + CK_TILE_DEVICE constexpr auto get_y_indices(index_t i_access) const + { + return sfc_ys.get_index(i_access); + } + + // Get detailed vectorized access information + CK_TILE_DEVICE constexpr auto get_vectorized_access_info(index_t i_access) const + { + const auto base_indices = get_y_indices(i_access); + // Return structure with base indices, vector dimension, and size + return vectorized_access_info{ + base_indices, + vector_dim_y, + scalar_per_vector + }; + } + }; + +Vectorization 
Selection Algorithm +--------------------------------- + +LoadStoreTraits employs an advanced algorithm to select the best dimension for vectorization: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TD + A[Analyze Distribution] --> B{Check Each Dimension} + B --> C[Calculate Stride] + C --> D{Stride == 1?} + D -->|Yes| E[Candidate for Vectorization] + D -->|No| F[Skip Dimension] + E --> G[Check Alignment] + G --> H[Check Vector Size] + H --> I[Score Dimension] + F --> B + I --> J[Select Best Dimension] + J --> K[Configure Vector Access] + + style A fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style J fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style K fill:#fff3e0,stroke:#f57c00,stroke-width:2px + + + + +.. image:: diagrams/load_store_traits_1.svg + :alt: Diagram + :align: center + +**Example: Comparing Different Memory Layouts** + +.. code-block:: cpp + + // Row-major layout [4×16] + using RowMajorDist = tile_distribution_encoding< + sequence<>, // No replication + tuple, sequence<4, 4>>, // 4x16 total + tuple, sequence<1>>, // Thread mapping + tuple, sequence<0>>, // Minor indices + sequence<2, 4>, // Y-space per thread + sequence<1, 1> // Y-space minor + >; + + // Column-major layout [16×4] + using ColMajorDist = tile_distribution_encoding< + sequence<>, // No replication + tuple, sequence<2, 2>>, // 16x4 total + tuple, sequence<1>>, // Thread mapping + tuple, sequence<0>>, // Minor indices + sequence<4, 2>, // Y-space per thread + sequence<1, 1> // Y-space minor + >; + + // LoadStoreTraits analysis + using RowTraits = load_store_traits; + using ColTraits = load_store_traits; + + // Row-major: vectorizes dimension 1 (4 elements) + static_assert(RowTraits::vector_dim_y == 1); + static_assert(RowTraits::scalar_per_vector == 4); + + // Column-major: vectorizes dimension 1 (2 elements) + static_assert(ColTraits::vector_dim_y == 1); + static_assert(ColTraits::scalar_per_vector == 2); + +Memory Access Patterns 
+---------------------- + +LoadStoreTraits creates efficient access patterns using space-filling curves: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Linear Traversal" + L1["0→1→2→3"] + L2["4→5→6→7"] + L3["Cache miss"] + L4["8→9→10→11"] + end + + subgraph "Snake Pattern" + S1["0→1→2→3"] + S2["7←6←5←4"] + S3["Cache hit!"] + S4["8→9→10→11"] + end + + L1 --> L2 + L2 --> L3 + L3 --> L4 + + S1 --> S2 + S2 --> S3 + S3 --> S4 + + style L3 fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style S3 fill:#d1fae5,stroke:#10b981,stroke-width:2px + + + +.. image:: diagrams/load_store_traits_2.svg + :alt: Diagram + :align: center + +**C++ Access Pattern Example:** + +.. code-block:: cpp + + // Create a 6x8 tile distribution + using TileDist = tile_distribution_encoding< + sequence<>, + tuple, sequence<2, 4>>, // 6x8 tile + tuple, sequence<1>>, + tuple, sequence<0>>, + sequence<3, 4>, // 3x4 per thread + sequence<1, 1> + >; + + using Traits = load_store_traits; + + // Access pattern visualization + template + CK_TILE_DEVICE void visualize_access_pattern() + { + printf("Tile: %dx%d\n", TileDist::get_tile_m(), TileDist::get_tile_n()); + printf("Vector dimension: %d\n", Traits::vector_dim_y); + printf("Scalars per access: %d\n", Traits::scalars_per_access); + printf("\nAccess sequence:\n"); + + // Show first few accesses + static_for<0, min(6, Traits::num_access), 1>{}([](auto i) { + const auto indices = Traits::get_y_indices(i); + const auto info = Traits::get_vectorized_access_info(i); + + printf("Access %d: Base=[%d,%d], Vector size=%d\n", + i, indices[0], indices[1], info.vector_size); + }); + } + +Performance Analysis +-------------------- + +Memory Access Efficiency +~~~~~~~~~~~~~~~~~~~~~~~~ + +LoadStoreTraits optimizes for several performance metrics: + +.. 
code-block:: cpp + + template + struct memory_access_analyzer + { + using Traits = load_store_traits; + + // Calculate memory bandwidth utilization + static constexpr float bandwidth_utilization() + { + constexpr index_t bytes_per_access = Traits::scalar_per_vector * sizeof(float); + constexpr index_t cache_line_size = 64; // bytes + return static_cast(bytes_per_access) / cache_line_size * 100.0f; + } + + // Calculate total memory transactions + static constexpr index_t total_transactions() + { + return Traits::num_access; + } + + // Check coalescing efficiency (see :ref:`ck_tile_gpu_basics`) + static constexpr bool is_perfectly_coalesced() + { + // Perfect coalescing when adjacent threads access adjacent memory + return Traits::vector_dim_y == Distribution::ndim_y - 1 && + Traits::scalar_per_vector >= 4; + } + }; + +Comparing Different Configurations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + // Configuration 1: Simple 8x8 tile + using Simple8x8 = tile_distribution_encoding< + sequence<>, + tuple, sequence<2, 4>>, + tuple, sequence<1>>, + tuple, sequence<0>>, + sequence<4, 4>, + sequence<1, 1> + >; + + // Configuration 2: Optimized for vectorization + using OptimizedVector = tile_distribution_encoding< + sequence<>, + tuple, sequence<2, 8>>, + tuple, sequence<1>>, + tuple, sequence<0>>, + sequence<2, 8>, // 2x8 per thread for better vectorization + sequence<1, 1> + >; + + // Analysis + using SimpleAnalyzer = memory_access_analyzer; + using OptimizedAnalyzer = memory_access_analyzer; + + static_assert(SimpleAnalyzer::bandwidth_utilization() == 25.0f); // 4*4/64 + static_assert(OptimizedAnalyzer::bandwidth_utilization() == 50.0f); // 8*4/64 + + // Better bandwidth utilization leads to improved performance + // See :ref:`ck_tile_gemm_optimization` for real-world examples + +Integration with Space-Filling Curves +------------------------------------- + +LoadStoreTraits automatically configures space-filling curves for optimal access: + +.. 
code-block:: cpp + + template + struct space_filling_curve_optimizer + { + using Traits = load_store_traits; + + static constexpr auto create_optimized_curve() + { + // Move vector dimension to end of access order + array dim_order; + + // Fill non-vector dimensions first + index_t pos = 0; + for (index_t i = 0; i < Distribution::ndim_y; ++i) { + if (i != Traits::vector_dim_y) { + dim_order[pos++] = i; + } + } + + // Vector dimension last for contiguous access + dim_order[pos] = Traits::vector_dim_y; + + // Create space-filling curve with optimized order + return space_filling_curve{ + Distribution::get_y_lengths(), + dim_order, + Traits::scalar_per_vector, + true // Enable snake pattern + }; + } + }; + +Advanced Optimizations +---------------------- + +Multi-Level Vectorization +~~~~~~~~~~~~~~~~~~~~~~~~~ + +For complex :ref:`distributions `, LoadStoreTraits can identify multiple levels of vectorization: + +.. code-block:: cpp + + template + struct multi_level_vectorization + { + // Primary vector dimension (innermost, stride 1) + static constexpr index_t primary_vector_dim = + load_store_traits::vector_dim_y; + + // Secondary vector dimension (next best option) + static constexpr index_t secondary_vector_dim = []() { + const auto strides = Distribution::calculate_y_strides(); + for (index_t i = 0; i < Distribution::ndim_y; ++i) { + if (i != primary_vector_dim && + strides[i] <= 4) { // Small stride + return i; + } + } + return -1; + }(); + + // Can use 2D vectorization? + static constexpr bool supports_2d_vector = secondary_vector_dim >= 0; + }; + +Adaptive Vector Size Selection +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +LoadStoreTraits adapts vector size based on multiple factors: + +.. 
code-block:: cpp + + template + struct adaptive_vector_size + { + static constexpr index_t calculate_optimal_vector_size() + { + constexpr index_t dim_length = + Distribution::get_y_length(load_store_traits::vector_dim_y); + + // Hardware-specific vector sizes + constexpr array valid_sizes = {8, 4, 2, 1}; + + // Find largest valid size that divides dimension length + for (auto size : valid_sizes) { + if (dim_length % size == 0 && + size * sizeof(DataType) <= 32) { // Max vector register size + return size; + } + } + return 1; + } + }; + +Best Practices +-------------- + +1. **Design Distributions for Vectorization** + + .. code-block:: cpp + + // Good: Inner dimension is power of 2 + using GoodDist = tile_distribution_encoding< + sequence<>, + tuple, sequence<2, 8>>, // Inner dim = 16 + tuple, sequence<1>>, + tuple, sequence<0>>, + sequence<2, 8>, // 8 elements for vectorization + sequence<1, 1> + >; + +2. **Consider Data Type Size** + + .. code-block:: cpp + + // Adjust distribution based on data type + template + using AdaptiveDist = std::conditional_t< + sizeof(DataType) == 2, // FP16 + tile_distribution_encoding<...>, // 8-wide vectors + tile_distribution_encoding<...> // 4-wide vectors for FP32 + >; + +3. **Align for Cache Lines** + + .. code-block:: cpp + + // Ensure tile dimensions align with cache lines + static_assert(TileDist::get_tile_n() * sizeof(float) % 64 == 0, + "Tile width should align to cache lines"); + + For more optimization techniques, see :ref:`ck_tile_lds_bank_conflicts` and :ref:`ck_tile_lds_index_swapping`. + +Summary +------- + +LoadStoreTraits provides: + +- **Automatic vectorization analysis**: Identifies optimal dimensions and vector sizes +- **Space-filling curve optimization**: Creates cache-friendly access patterns. See :ref:`ck_tile_space_filling_curve` for more information. 
+- **Compile-time optimization**: All analysis done at compile time for zero overhead +- **Hardware adaptation**: Adjusts to different data types and :ref:`architectures ` +- **Performance transparency**: Clear metrics for memory efficiency + +The compile-time analysis performed by LoadStoreTraits ensures that every memory operation in CK Tile achieves near-optimal performance, making it a critical component in the high-performance computing stack. + +Next Steps +---------- + +- :ref:`ck_tile_space_filling_curve` - Deep dive into traversal patterns +- :ref:`ck_tile_tile_window` - How LoadStoreTraits enables efficient data access +- :ref:`ck_tile_static_distributed_tensor` - The target of optimized loads/stores +- :ref:`ck_tile_coordinate_systems` - Understanding the coordinate transformations +- :ref:`ck_tile_gemm_optimization` - Real-world application of LoadStoreTraits diff --git a/docs/conceptual/ck_tile/space_filling_curve.rst b/docs/conceptual/ck_tile/space_filling_curve.rst new file mode 100644 index 0000000000..4b95f71a69 --- /dev/null +++ b/docs/conceptual/ck_tile/space_filling_curve.rst @@ -0,0 +1,511 @@ +.. _ck_tile_space_filling_curve: + +Space-Filling Curves - Optimal Memory Traversal +=============================================== + +Overview +-------- + +The SpaceFillingCurve (SFC) class provides a systematic way to traverse multi-dimensional tensors, supporting both scalar and vectorized access patterns. This is particularly important for optimizing memory access patterns in :ref:`GPU kernels `, where the order of memory accesses can dramatically impact performance through cache utilization, memory coalescing, and prefetching effectiveness. + +A space-filling curve is a continuous curve that visits every point in a multi-dimensional space exactly once. In the context of CK Tile, it defines a mapping from a 1D access index to multi-dimensional :ref:`tensor coordinates `, enabling efficient traversal patterns that maximize hardware utilization. 
+ +Key Concepts +------------ + +Tensor Traversal +~~~~~~~~~~~~~~~~ + +The space-filling curve defines a mapping from a 1D access index to multi-dimensional tensor coordinates. This abstraction allows complex multi-dimensional access patterns to be expressed as simple linear iterations, while maintaining optimal memory access characteristics. + +Vectorized Access +~~~~~~~~~~~~~~~~~ + +:ref:`GPUs <ck_tile_gpu_basics>` support vector load and store instructions that can access multiple consecutive elements in a single operation. SpaceFillingCurve supports this by allowing specification of how many elements to access per dimension (``scalars_per_access``), enabling efficient utilization of these hardware features. + +Dimension Ordering +~~~~~~~~~~~~~~~~~~ + +The order in which dimensions are traversed impacts memory access patterns. Row-major vs column-major ordering, for example, can mean the difference between the preferred sequential memory access and strided access, which can potentially cause cache thrashing. + +Snake Patterns +~~~~~~~~~~~~~~ + +Snake, or serpentine, patterns reverse the traversal direction on alternate rows and planes, keeping consecutive accesses spatially close. This pattern is particularly effective for maintaining cache locality when moving between rows or higher-dimensional boundaries. + +Usage +~~~~~ + +SFC is mainly used by tile transpose, tile shuffling iteration, and CShuffle to access tile data in the specific order the application requires while achieving the best possible cache hit rate. + +C++ Implementation +------------------ + +The C++ template class provides compile-time optimization of traversal patterns: + +.. 
code-block:: cpp + + template + struct space_filling_curve + { + static constexpr index_t ndim = NDimSFC; + static constexpr auto tensor_lengths = SFCLengths{}; + static constexpr auto dim_access_order = DimAccessOrder{}; + static constexpr auto scalars_per_access = ScalarsPerAccess{}; + static constexpr bool snake_curved = IsSnakeCurved; + + // Calculate access dimensions (with ceiling division) + static constexpr auto access_lengths = []() { + array lengths; + for (index_t i = 0; i < ndim; ++i) { + lengths[i] = (tensor_lengths[i] + scalars_per_access[i] - 1) + / scalars_per_access[i]; + } + return lengths; + }(); + + // Total number of accesses needed + static constexpr index_t get_num_of_access() + { + index_t total = 1; + for (index_t i = 0; i < ndim; ++i) { + total *= access_lengths[i]; + } + return total; + } + + // Convert 1D access index to N-D coordinates + CK_TILE_DEVICE constexpr auto get_index(index_t i_access) const + { + array indices; + + // Calculate indices in access space + index_t remaining = i_access; + for (index_t i = ndim - 1; i >= 0; --i) { + const index_t dim = dim_access_order[i]; + indices[dim] = remaining % access_lengths[dim]; + remaining /= access_lengths[dim]; + } + + // Apply snake pattern if enabled + if constexpr (snake_curved) { + apply_snake_pattern(indices); + } + + // Scale by scalars_per_access + for (index_t i = 0; i < ndim; ++i) { + indices[i] *= scalars_per_access[i]; + } + + return indices; + } + + // Calculate step between two accesses + CK_TILE_DEVICE constexpr auto get_step_between( + index_t start, index_t end) const + { + const auto start_idx = get_index(start); + const auto end_idx = get_index(end); + + array step; + for (index_t i = 0; i < ndim; ++i) { + step[i] = end_idx[i] - start_idx[i]; + } + return step; + } + }; + +Basic Usage Examples +-------------------- + +Scalar Access Patterns +~~~~~~~~~~~~~~~~~~~~~~ + +.. 
code-block:: cpp + + // Row-major traversal of 4x6 matrix + using RowMajorCurve = space_filling_curve< + 2, // 2D + sequence<4, 6>, // Shape: 4x6 + sequence<0, 1>, // Dimension order: row then column + sequence<1, 1>, // Scalar access + false // No snake pattern + >; + + // Total accesses needed + constexpr index_t num_access = RowMajorCurve::get_num_of_access(); // 24 + + // Access pattern (first 10) + static_for<0, 10, 1>{}([](auto i) { + constexpr auto indices = RowMajorCurve{}.get_index(i); + printf("Access %d: [%d, %d]\n", i, indices[0], indices[1]); + }); + // Output: [0,0], [0,1], [0,2], [0,3], [0,4], [0,5], [1,0], [1,1], ... + +Vectorized Access Patterns +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + // Vector-4 access on dimension 1 + using VectorizedCurve = space_filling_curve< + 2, // 2D + sequence<4, 8>, // Shape: 4x8 + sequence<0, 1>, // Row-major + sequence<1, 4>, // Vector-4 on dimension 1 + false + >; + + // Access pattern visualization + static_for<0, VectorizedCurve::get_num_of_access(), 1>{}([](auto i) { + constexpr auto indices = VectorizedCurve{}.get_index(i); + printf("Access %d: row %d, cols [%d:%d]\n", + i, indices[0], indices[1], indices[1] + 3); + }); + // Output: row 0, cols [0:3], row 0, cols [4:7], row 1, cols [0:3], ... + +Column-Major vs Row-Major +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + // Compare access patterns + using RowMajor = space_filling_curve<2, sequence<4, 6>, + sequence<0, 1>, sequence<1, 1>, false>; + using ColMajor = space_filling_curve<2, sequence<4, 6>, + sequence<1, 0>, sequence<1, 1>, false>; + + // Row-major: [0,0], [0,1], [0,2], ... (traverse rows) + // Col-major: [0,0], [1,0], [2,0], ... (traverse columns) + +Advanced Patterns +----------------- + +Snake Pattern for Cache Optimization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The snake pattern reverses traversal direction on alternate rows, minimizing the distance between consecutive accesses: + +.. 
+ Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Linear Pattern" + L1["Row 0: →"] + L2["Row 1: →"] + L3["Jump back"] + L4["Row 2: →"] + end + + subgraph "Snake Pattern" + S1["Row 0: →"] + S2["Row 1: ←"] + S3["Continue"] + S4["Row 2: →"] + end + + L1 --> L3 + L3 --> L2 + L2 --> L3 + L3 --> L4 + + S1 --> S2 + S2 --> S4 + + style L3 fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style S3 fill:#d1fae5,stroke:#10b981,stroke-width:2px + + + + + +.. image:: diagrams/space_filling_curve.svg + :alt: Diagram + :align: center + +.. code-block:: cpp + + using SnakeCurve = space_filling_curve< + 2, + sequence<4, 8>, + sequence<0, 1>, + sequence<1, 1>, + true // Enable snake pattern + >; + + // Access pattern with snake: + // Row 0: [0,0], [0,1], [0,2], ..., [0,7] + // Row 1: [1,7], [1,6], [1,5], ..., [1,0] (reversed!) + // Row 2: [2,0], [2,1], [2,2], ..., [2,7] + // Row 3: [3,7], [3,6], [3,5], ..., [3,0] (reversed!) + +GEMM Tile Access Pattern +~~~~~~~~~~~~~~~~~~~~~~~~ + +For :ref:`matrix multiplication `, optimal access patterns are crucial: + +.. code-block:: cpp + + // GEMM tile: 16x32 with vector-8 loads + // Column-major for coalesced access in GEMM + // See :ref:`ck_tile_gemm_optimization` for complete example + using GemmTileCurve = space_filling_curve< + 2, + sequence<16, 32>, // Tile size + sequence<1, 0>, // Column-major + sequence<1, 8>, // Vector-8 loads + false + >; + + // This creates a pattern where: + // - Each access loads 8 consecutive elements + // - Accesses proceed down columns (coalesced for column-major storage) + // - Total accesses: 16 * (32/8) = 64 + +3D Tensor Patterns +~~~~~~~~~~~~~~~~~~ + +.. 
code-block:: cpp + + // 3D tensor with mixed vectorization + using Tensor3D = space_filling_curve< + 3, + sequence<4, 8, 16>, // 4x8x16 tensor + sequence<0, 1, 2>, // Access order + sequence<1, 2, 4>, // Different vector sizes per dimension + false + >; + + // Access pattern: + // - Dimension 0: scalar access + // - Dimension 1: vector-2 access + // - Dimension 2: vector-4 access + // Total accesses: 4 * (8/2) * (16/4) = 64 + +Performance Analysis +-------------------- + +Step Analysis for Memory Patterns +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Understanding step patterns between accesses is crucial for performance: + +.. code-block:: cpp + + template + struct access_pattern_analyzer + { + static constexpr void analyze_locality() + { + index_t sequential_steps = 0; + index_t cache_line_jumps = 0; + index_t large_jumps = 0; + + constexpr SFC sfc{}; + + for (index_t i = 0; i < SFC::get_num_of_access() - 1; ++i) { + const auto step = sfc.get_step_between(i, i + 1); + + // Calculate Manhattan distance + index_t distance = 0; + for (index_t d = 0; d < SFC::ndim; ++d) { + distance += abs(step[d]); + } + + if (distance <= 1) { + sequential_steps++; + } else if (distance <= 16) { // Within cache line + cache_line_jumps++; + } else { + large_jumps++; + } + } + + // Report statistics... + } + }; + +Optimizing for Hardware +~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
code-block:: cpp + + // Optimize for GPU memory coalescing (see :ref:`ck_tile_gpu_basics`) + template + struct coalesced_access_pattern + { + // For coalescing, adjacent threads should access adjacent memory + static constexpr index_t vector_size = sizeof(float4) / sizeof(DataType); + + using OptimalPattern = space_filling_curve< + 2, + sequence, + sequence<1, 0>, // Column-major for coalescing + sequence<1, vector_size>, // Vectorized on fast-changing dimension + false + >; + }; + +Handling Edge Cases +------------------- + +Non-Divisible Dimensions +~~~~~~~~~~~~~~~~~~~~~~~~ + +When tensor dimensions aren't evenly divisible by vector size: + +.. code-block:: cpp + + // 5x7 tensor with 2x3 access pattern + using EdgeCaseCurve = space_filling_curve< + 2, + sequence<5, 7>, + sequence<0, 1>, + sequence<2, 3>, + false + >; + + // Access lengths use ceiling division: ceil(5/2) x ceil(7/3) = 3x3 + static_assert(EdgeCaseCurve::access_lengths[0] == 3); + static_assert(EdgeCaseCurve::access_lengths[1] == 3); + + // Boundary handling needed for accesses that exceed tensor bounds + template + CK_TILE_DEVICE void safe_access(index_t i_access) + { + const auto indices = SFC{}.get_index(i_access); + + // Check bounds for each dimension + bool in_bounds = true; + for (index_t d = 0; d < SFC::ndim; ++d) { + if (indices[d] + SFC::scalars_per_access[d] > SFC::tensor_lengths[d]) { + in_bounds = false; + break; + } + } + + if (in_bounds) { + // Full vector access + } else { + // Partial access with masking + } + } + +Integration with CK Tile +------------------------ + +LoadStoreTraits Integration +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +:ref:`LoadStoreTraits ` uses space-filling curves to optimize memory access: + +.. 
code-block:: cpp + + template + struct load_store_traits + { + // Create optimized space-filling curve + // See :ref:`ck_tile_tile_distribution` for Distribution details + using sfc_type = space_filling_curve< + Distribution::ndim_y, + typename Distribution::y_lengths, + optimized_dim_order, // Computed order + optimized_scalars_per_access, + true // Enable snake for cache optimization + >; + + static constexpr sfc_type sfc_ys{}; + }; + +TileWindow Usage +~~~~~~~~~~~~~~~~ + +:ref:`TileWindow ` leverages space-filling curves for systematic tile traversal: + +.. code-block:: cpp + + template + CK_TILE_DEVICE void process_tile(const TileWindow& window) + { + using Traits = typename TileWindow::traits_type; + constexpr auto sfc = Traits::sfc_ys; + + // Traverse tile using space-filling curve + static_for<0, sfc.get_num_of_access(), 1>{}([&](auto i) { + const auto indices = sfc.get_index(i); + // Process element at indices... + }); + } + +Best Practices +-------------- + +1. **Choose Appropriate Dimension Order** + + .. code-block:: cpp + + // For row-major storage, use row-major traversal + using RowMajorSFC = space_filling_curve<2, Shape, sequence<0, 1>, ...>; + + // For column-major storage, use column-major traversal + using ColMajorSFC = space_filling_curve<2, Shape, sequence<1, 0>, ...>; + +2. **Optimize Vector Size** + + .. code-block:: cpp + + // Match vector size to cache line for optimal bandwidth + // See :ref:`ck_tile_lds_bank_conflicts` for cache optimization + constexpr index_t optimal_vector = min( + tensor_length_fast_dim, + cache_line_size / sizeof(DataType) + ); + +3. **Enable Snake Pattern for Large Tensors** + + .. code-block:: cpp + + // Snake pattern helps when jumping between rows/planes + using CacheFriendlySFC = space_filling_curve< + NDim, Lengths, Order, Scalars, + true // Enable snake + >; + +4. **Consider Memory Layout** + + .. 
code-block:: cpp + + // Align access patterns with physical memory layout + static_assert( + SFC::scalars_per_access[fastest_dim] * sizeof(DataType) + % cache_line_size == 0, + "Vector access should align with cache lines" + ); + +Summary +------- + +Space-filling curves provide: + +- **Systematic traversal**: Convert N-D access to 1D iteration +- **Vectorization support**: Efficient use of vector load and store instructions +- **Cache optimization**: Snake patterns and dimension ordering for locality +- **Flexibility**: Adaptable to different :ref:`tensor shapes ` and access patterns +- **Performance**: Compile-time optimization with zero runtime overhead + +The advanced traversal patterns enabled by space-filling curves are fundamental to achieving high performance in GPU kernels, ensuring that memory access patterns align with :ref:`hardware capabilities `. + +Next Steps +---------- + +- :ref:`ck_tile_load_store_traits` - How curves optimize memory access +- :ref:`ck_tile_sweep_tile` - Traversing distributed tensors +- :ref:`ck_tile_static_distributed_tensor` - The data structures being traversed +- :ref:`ck_tile_tile_window` - Integration with data loading +- :ref:`ck_tile_gemm_optimization` - Real-world application example diff --git a/docs/conceptual/ck_tile/static_distributed_tensor.rst b/docs/conceptual/ck_tile/static_distributed_tensor.rst new file mode 100644 index 0000000000..bfd50c0899 --- /dev/null +++ b/docs/conceptual/ck_tile/static_distributed_tensor.rst @@ -0,0 +1,429 @@ +.. meta:: + :description: CK Tile static distributed tensor documentation + :keywords: CK Tile, static distributed tensor, thread-local storage, GPU programming, ROCM + +.. _ck_tile_static_distributed_tensor: + +************************* +Static Distributed Tensor +************************* + +Overview +======== + +Static distributed tensors represent the thread-local data containers in CK Tile's programming model. 
Unlike traditional GPU programming where developers manually manage thread-local arrays and coordinate access patterns, static distributed tensors provide a high-level abstraction that automatically handles data distribution across threads while maintaining the performance characteristics of register-based storage. + +Each thread in a workgroup owns a portion of the overall tensor data, stored in its registers or local memory. The distribution pattern follows the :ref:`tile distribution ` rules, ensuring that collective operations across all threads reconstruct the complete logical tensor while individual threads operate only on their local portions. + +This design enables three critical optimizations: + + * It maximizes register utilization by keeping frequently accessed data in the fastest memory hierarchy. + * It eliminates redundant memory accesses since each thread maintains its own working set. + * It provides a clean abstraction for complex algorithms like matrix multiplication where each thread accumulates partial results that eventually combine into the final output. + +Thread-Local Storage Model +========================== + +The static distributed tensor implements an advanced storage model that maps multi-dimensional tensor data to thread-local arrays: + +.. code-block:: cpp + + template + struct StaticDistributedTensor { + // Each thread stores its portion of the tensor + static constexpr index_t kNumElements = + TileDistribution::GetNumElementsPerThread(); + + // Thread-local storage - typically maps to registers + DataType data_[kNumElements]; + + // Access using Y-space coordinates (see :ref:`ck_tile_coordinate_systems`) + __device__ DataType& operator()(const YIndex& idx) { + // Convert Y coordinate to local buffer offset + index_t offset = TileDistribution::YToLocalOffset(idx); + return data_[offset]; + } + }; + +The storage layout follows these principles: + +1. 
**Contiguous Storage**: Each thread's data is stored in a contiguous array, optimizing register allocation and enabling vectorized operations. + +2. **Deterministic Mapping**: The Y-coordinate to buffer offset mapping is computed at compile time, eliminating runtime overhead. + +3. **Alignment Guarantees**: The storage layout respects hardware alignment requirements for efficient memory operations. + +Memory Layout and Access Patterns +================================= + +Understanding how static distributed tensors organize memory is important for performance optimization. Consider a 2D tensor distributed across a 2D thread block: + +.. code-block:: cpp + + // Define a 64x64 tensor distributed across 16x16 threads + using TileDist = TileDistribution< + Sequence<64, 64>, // Tensor dimensions + Sequence<16, 16> // Thread block dimensions + >; + + // Each thread owns a 4x4 subtensor + using MyTensor = StaticDistributedTensor; + + __device__ void example_kernel() { + MyTensor accumulator; + + // Initialize thread-local data + for(index_t i = 0; i < 4; ++i) { + for(index_t j = 0; j < 4; ++j) { + // Y-space coordinates for this thread's elements + YIndex y_idx = make_tuple( + threadIdx.y * 4 + i, + threadIdx.x * 4 + j + ); + accumulator(y_idx) = 0.0f; + } + } + } + +The memory layout follows a hierarchical pattern: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TD + A[Global Tensor 64x64] --> B[Thread Block 16x16] + B --> C[Thread 0,0
Elements 0:3,0:3] + B --> D[Thread 0,1
Elements 0:3,4:7] + B --> E[Thread 1,0
Elements 4:7,0:3] + B --> F[...] + + C --> G[Local Array
16 elements] + D --> H[Local Array
16 elements] + E --> I[Local Array
16 elements] + + + + + +.. image:: diagrams/static_distributed_tensor.svg + :alt: Diagram + :align: center + +Element Access and Indexing +=========================== + +Static distributed tensors provide multiple indexing modes to support different access patterns: + +.. code-block:: cpp + + template + class StaticDistributedTensor { + public: + // Y-space indexing (most common) - see :ref:`ck_tile_coordinate_systems` + __device__ DataType& operator()(const YIndex& y_idx) { + return data_[YToOffset(y_idx)]; + } + + // Direct buffer indexing (for vectorized operations) + __device__ DataType& operator[](index_t offset) { + return data_[offset]; + } + + // Structured access for multi-dimensional patterns + template + __device__ DataType& at(Coords... coords) { + YIndex y_idx = make_tuple(coords...); + return (*this)(y_idx); + } + + // Vectorized access for performance + template + __device__ auto get_vector(index_t offset) { + using VectorType = vector_type_t; + return *reinterpret_cast(&data_[offset]); + } + }; + +The indexing system supports several optimization strategies: + +1. **Compile-Time Resolution**: When indices are known at compile time, the compiler can optimize away all indexing calculations. + +2. **Vectorized Access**: Accessing multiple elements as vectors enables efficient register-to-register transfers. + +3. **Boundary Checking**: Debug builds include automatic boundary checking to catch indexing errors early. + +Thread Coordination and Synchronization +======================================= + +Static distributed tensors excel at patterns where threads cooperate to process larger data structures: + +.. 
code-block:: cpp + + // Matrix multiplication accumulator pattern + // See :ref:`ck_tile_gemm_optimization` for complete example + template + __device__ void gemm_accumulate( + const TileWindow& a_window, + const TileWindow& b_window, + StaticDistributedTensor& c_accumulator) + { + // Each thread accumulates its portion + constexpr index_t kInnerTiles = 8; + + for(index_t k = 0; k < kInnerTiles; ++k) { + // Load tiles from global memory + auto a_tile = a_window.load(k); + auto b_tile = b_window.load(k); + + // Synchronize to ensure all loads complete + __syncthreads(); + + // Local accumulation (no synchronization needed) + for(index_t i = 0; i < 4; ++i) { + for(index_t j = 0; j < 4; ++j) { + CType sum = 0; + for(index_t kk = 0; kk < 4; ++kk) { + sum += a_tile(i, kk) * b_tile(kk, j); + } + c_accumulator.at(i, j) += sum; + } + } + } + } + +Key coordination patterns include: + +1. **Accumulation**: Each thread maintains partial results that combine to form the final answer. + +2. **Scatter/Gather**: Threads can efficiently reorganize data through coordinated read/write patterns. + +3. **Reduction**: Tree-based reduction algorithms naturally map to the distributed storage model. + +Practical Usage Patterns +======================== + +Static distributed tensors are useful in many common GPU programming patterns: + +**1. Register Blocking for Matrix Operations** + +.. code-block:: cpp + + // Optimize register usage for small matrix tiles + template + struct RegisterTile { + using Distribution = TileDistribution< + Sequence, + Sequence<1, 1> // Single thread owns entire tile + >; + using Tensor = StaticDistributedTensor; + + __device__ void compute() { + Tensor tile; + // All M*N elements in registers of one thread + // Enables aggressive unrolling and scheduling + } + }; + +**2. Warp-Level Primitives** + +.. 
code-block:: cpp + + // Distribute across warp for collaborative operations + template + struct WarpDistributedVector { + using Distribution = TileDistribution< + Sequence<32>, // 32 elements + Sequence<32> // 32 threads in warp + >; + using Tensor = StaticDistributedTensor; + + __device__ T warp_reduce_sum() { + Tensor data; + // Each thread has one element + // Use warp shuffle for reduction + T value = data[0]; + for(int offset = 16; offset > 0; offset /= 2) { + value += __shfl_down_sync(0xffffffff, value, offset); + } + return value; + } + }; + +**3. Shared Memory Staging** + +.. code-block:: cpp + + // Combine with shared memory for complex patterns + // See :ref:`ck_tile_lds_bank_conflicts` for LDS optimization + template + struct StagedComputation { + using RegTensor = StaticDistributedTensor; + using LdsTensor = StaticDistributedTensor; + + __device__ void process() { + RegTensor reg_data; + __shared__ T shared_buffer[1024]; + + // Stage 1: Compute in registers + compute_local(reg_data); + + // Stage 2: Exchange through shared memory + store_to_lds(reg_data, shared_buffer); + __syncthreads(); + + // Stage 3: Load different pattern + LdsTensor lds_data; + load_from_lds(shared_buffer, lds_data); + } + }; + +Performance Considerations +========================== + +Optimizing static distributed tensor usage requires understanding several :ref:`performance factors `: + +**Register Pressure**: Each thread's local storage typically maps to registers. Excessive storage requirements can cause register spilling: + +.. code-block:: cpp + + // Monitor register usage + template + struct RegisterAnalysis { + static constexpr index_t kRegistersPerElement = sizeof(T) / 4; + static constexpr index_t kTotalRegisters = Size * kRegistersPerElement; + + static_assert(kTotalRegisters <= 64, + "Exceeds typical register budget"); + }; + +**Memory Coalescing**: When loading/storing distributed tensors, ensure access patterns promote coalescing. 
See :ref:`ck_tile_gpu_basics` for more information about coalescing. + +.. code-block:: cpp + + // Good: Coalesced access pattern + template + __device__ void coalesced_store(Tensor& tensor, float* global_ptr) { + index_t tid = threadIdx.x + blockIdx.x * blockDim.x; + #pragma unroll + for(index_t i = 0; i < Tensor::kNumElements; ++i) { + global_ptr[tid + i * gridDim.x * blockDim.x] = tensor[i]; + } + } + +**Instruction Scheduling**: Organize operations to maximize instruction-level parallelism: + +.. code-block:: cpp + + // Interleave independent operations + template + __device__ void optimized_accumulate(Tensor& acc, + const Tensor& a, + const Tensor& b) { + #pragma unroll + for(index_t i = 0; i < Tensor::kNumElements; i += 4) { + // Group independent operations + float tmp0 = a[i+0] * b[i+0]; + float tmp1 = a[i+1] * b[i+1]; + float tmp2 = a[i+2] * b[i+2]; + float tmp3 = a[i+3] * b[i+3]; + + // Accumulate after multiplies complete + acc[i+0] += tmp0; + acc[i+1] += tmp1; + acc[i+2] += tmp2; + acc[i+3] += tmp3; + } + } + +Integration with CK Tile Ecosystem +================================== + +Static distributed tensors integrate seamlessly with other CK Tile components: + +.. 
code-block:: cpp + + // Complete example: Distributed GEMM kernel + template + __global__ void distributed_gemm_kernel( + const float* __restrict__ a_ptr, + const float* __restrict__ b_ptr, + float* __restrict__ c_ptr, + index_t M, index_t N, index_t K) + { + // Define distributions + constexpr index_t kTileM = 128; + constexpr index_t kTileN = 128; + constexpr index_t kTileK = 32; + + using ATileDist = TileDistribution< + Sequence, + Sequence<32, 8> + >; + using BTileDist = TileDistribution< + Sequence, + Sequence<8, 32> + >; + using CTileDist = TileDistribution< + Sequence, + Sequence<32, 32> + >; + + // Create distributed accumulator + StaticDistributedTensor c_accumulator; + + // Initialize to zero + #pragma unroll + for(index_t i = 0; i < c_accumulator.kNumElements; ++i) { + c_accumulator[i] = 0.0f; + } + + // Main GEMM loop + for(index_t k_tile = 0; k_tile < K; k_tile += kTileK) { + // Create tile windows for this iteration + // See :ref:`ck_tile_tile_window` for details + auto a_window = make_tile_window( + a_ptr, ALayout{M, K}, + ATileDist{}, + {blockIdx.y * kTileM, k_tile} + ); + + auto b_window = make_tile_window( + b_ptr, BLayout{K, N}, + BTileDist{}, + {k_tile, blockIdx.x * kTileN} + ); + + // Load tiles to distributed tensors + // See :ref:`ck_tile_load_store_traits` for optimized loading + auto a_tile = a_window.load(); + auto b_tile = b_window.load(); + + // Distributed matrix multiply + distributed_gemm_accumulate(a_tile, b_tile, c_accumulator); + } + + // Store results + auto c_window = make_tile_window( + c_ptr, CLayout{M, N}, + CTileDist{}, + {blockIdx.y * kTileM, blockIdx.x * kTileN} + ); + c_window.store(c_accumulator); + } + +Summary +======= + +Static distributed tensors provide the foundation for high-performance thread-local computation in CK Tile. 
By abstracting the complexities of register allocation, thread coordination, and memory access patterns, they enable developers to write clear, maintainable code that achieves hardware efficiency. The key benefits include: + +- **Automatic Distribution**: The :ref:`tile distribution <ck_tile_tile_distribution>` system handles all thread-to-data mapping +- **Register Efficiency**: Thread-local storage maps directly to registers when possible +- **Zero-Overhead Abstraction**: All distribution logic resolves at compile time +- **Seamless Integration**: Works naturally with :ref:`tile windows <ck_tile_tile_window>`, :ref:`descriptors <ck_tile_descriptors>`, and other CK Tile components +- **Performance Transparency**: The storage model makes performance characteristics clear and predictable + +When combined with the broader CK Tile ecosystem, static distributed tensors enable the construction of complex GPU kernels that match hand-tuned assembly performance while maintaining the clarity of high-level mathematical expressions. diff --git a/docs/conceptual/ck_tile/sweep_tile.rst b/docs/conceptual/ck_tile/sweep_tile.rst new file mode 100644 index 0000000000..4dfb6a2ad1 --- /dev/null +++ b/docs/conceptual/ck_tile/sweep_tile.rst @@ -0,0 +1,560 @@ +.. meta:: + :description: CK Tile sweep operations documentation + :keywords: CK Tile, sweep operations, tile iteration, GPU programming + +.. _ck_tile_sweep_tile: + +********** +Sweep Tile +********** + +Overview +======== + +Sweep operations are the clean way to iterate over distributed data in CK Tile. They complete the tile distribution workflow by providing clean, efficient iteration patterns that automatically handle all the complex indexing details. Sweep operations are similar to a ``forEach()`` operation: they call a function for every data element. + +Sweep operations use the "load once, use many times" pattern. Load X data once into registers, then sweep through Y positions while keeping X in fast memory. This maximizes data reuse and minimizes memory bandwidth requirements.
+ +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + flowchart LR + subgraph "X-Tile (Reused)" + XT["X data loaded once
Stays in registers"] + end + + subgraph "Y-Sweep" + Y1["Y position 0"] + Y2["Y position 1"] + Y3["Y position 2"] + YN["Y position N"] + end + + subgraph "Computation" + C["Process(X, Y)"] + end + + XT --> C + Y1 --> C + Y2 --> C + Y3 --> C + YN --> C + + style XT fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style C fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + + + + + +.. image:: diagrams/sweep_tile_1.svg + :alt: Diagram + :align: center + +The Complete GPU Workflow +========================= + +Sweep operations are the final piece of the distributed computing puzzle: + +1. **TileDistribution**: "Here's how to divide work" +2. **TileWindow**: "Here's the data, loaded efficiently" +3. **Sweep Operations**: "Here's how to process every element" +4. **User code**: "Thanks! *does computation*" + +Without sweep operations, manual nested loops and complex index calculations are required, increasing the risk of missing elements or double-processing. Sweep operations provide lambda-based iteration with automatic handling of all elements. + +See :ref:`ck_tile_coordinate_systems` for more information about coordinate systems. + +Basic Sweep Implementation +========================== + +The fundamental sweep pattern in C++: + +.. 
code-block:: cpp + + template + __device__ void sweep_tile( + const DistributedTensor& tensor, + Func&& func) + { + // Get Y-space dimensions + constexpr auto y_lengths = tensor.get_tile_distribution() + .get_y_vector_lengths(); + + // Generate nested loops at compile time + static_for<0, y_lengths.size(), 1>{}([&](auto i) { + sweep_tile_impl(tensor, func, make_tuple()); + }); + } + + // Recursive implementation for arbitrary dimensions + template + __device__ void sweep_tile_impl( + const DistributedTensor& tensor, + Func&& func, + tuple indices) + { + constexpr auto y_lengths = tensor.get_tile_distribution() + .get_y_vector_lengths(); + + if constexpr (Dim == y_lengths.size()) { + // Base case: call function with complete indices + func(make_multi_index(indices...)); + } else { + // Recursive case: iterate this dimension + static_for<0, y_lengths[Dim], 1>{}([&](auto i) { + sweep_tile_impl( + tensor, func, + tuple_cat(indices, make_tuple(i)) + ); + }); + } + } + + + +Memory Efficiency Pattern +========================= + +The sweep pattern provides significant memory efficiency benefits. This is particularly important for GPU architectures (see :ref:`ck_tile_gpu_basics`) where memory bandwidth is often the limiting factor: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Traditional Approach" + T1["Load X[0]"] --> P1["Process"] + T2["Load Y[0]"] --> P1 + T3["Load X[0]"] --> P2["Process"] + T4["Load Y[1]"] --> P2 + T5["Load X[0]"] --> P3["Process"] + T6["Load Y[2]"] --> P3 + Note1["X loaded 3 times!"] + end + + subgraph "Sweep Approach" + S1["Load X[0]"] --> SP["Process with
Y[0], Y[1], Y[2]"] + S2["Load Y[0,1,2]"] --> SP + Note2["X loaded once!"] + end + + style Note1 fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style Note2 fill:#d1fae5,stroke:#10b981,stroke-width:2px + + + + + +.. image:: diagrams/sweep_tile_2.svg + :alt: Diagram + :align: center + +Practical Sweep Patterns +======================== + +Pattern 1: Simple Element Processing +------------------------------------ + +This pattern demonstrates the basic usage with :ref:`ck_tile_static_distributed_tensor`: + +.. code-block:: cpp + + template + __device__ void simple_sweep_example( + StaticDistributedTensor& input, + StaticDistributedTensor& output) + { + // Process each element + sweep_tile(input, [&](auto y_indices) { + DataType value = input.get_element(y_indices); + DataType result = compute_function(value); + output.set_element(y_indices, result); + }); + } + +Pattern 2: Accumulation +----------------------- + +.. code-block:: cpp + + template + __device__ DataType sweep_accumulate( + const StaticDistributedTensor& tensor) + { + DataType sum = 0; + + sweep_tile(tensor, [&](auto y_indices) { + sum += tensor.get_element(y_indices); + }); + + return sum; + } + +Pattern 3: Conditional Processing +--------------------------------- + +.. code-block:: cpp + + template + __device__ void conditional_sweep( + StaticDistributedTensor& tensor, + DataType threshold) + { + sweep_tile(tensor, [&](auto y_indices) { + DataType value = tensor.get_element(y_indices); + if (value > threshold) { + // Process only values above threshold + tensor.set_element(y_indices, process_large_value(value)); + } + }); + } + + +GEMM Sweep Pattern +================== + +The sweep pattern is fundamental to high-performance matrix multiplication. See :ref:`ck_tile_gemm_optimization` for more information about GEMM optimization details. + +.. 
code-block:: cpp + + template + __device__ void gemm_sweep_tile( + const TileWindow& a_window, + const TileWindow& b_window, + TileWindow& c_window) + { + // Phase 1: Load A tile into registers (X dimension) + auto a_tile = make_static_distributed_tensor(); + a_window.load(a_tile); // Load once, reuse many times + + // Phase 2: Create C accumulator + auto c_accumulator = make_static_distributed_tensor(); + + // Initialize accumulator + sweep_tile(c_accumulator, [&](auto y_indices) { + c_accumulator.set_element(y_indices, 0); + }); + + // Phase 3: Sweep through B positions (Y dimension) + constexpr index_t k_per_block = BDistribution::get_lengths()[1]; + + for (index_t k = 0; k < k_per_block; ++k) { + // Load current B slice + auto b_slice = make_static_distributed_tensor(); + b_window.load_slice(b_slice, k); + + // Compute C += A * B for this slice + sweep_tile(c_accumulator, [&](auto c_indices) { + CDataType sum = c_accumulator.get_element(c_indices); + + // Inner product for this C element + constexpr index_t inner_dim = ADistribution::get_lengths()[1]; + for (index_t i = 0; i < inner_dim; ++i) { + auto a_indices = make_multi_index(c_indices[0], i); + auto b_indices = make_multi_index(i, c_indices[1]); + + sum += a_tile.get_element(a_indices) * + b_slice.get_element(b_indices); + } + + c_accumulator.set_element(c_indices, sum); + }); + } + + // Phase 4: Store result + c_window.store(c_accumulator); + } + +Advanced Sweep Patterns +======================= + +Multi-Dimensional Sweep +----------------------- + +.. 
code-block:: cpp + + template + __device__ void tensor_3d_sweep( + StaticDistributedTensor& tensor) + { + // Sweep through 3D tensor with nested structure + sweep_tile(tensor, [&](auto indices) { + // indices is MultiIndex<3> with [d0, d1, d2] + index_t d0 = indices[0]; + index_t d1 = indices[1]; + index_t d2 = indices[2]; + + // Process based on 3D position + DataType value = tensor.get_element(indices); + + // Example: Different processing for different planes + if (d2 == 0) { + // First plane: special processing + value = special_process(value); + } else { + // Other planes: normal processing + value = normal_process(value); + } + + tensor.set_element(indices, value); + }); + } + +Strided Sweep +------------- + +.. code-block:: cpp + + template + __device__ void strided_sweep( + const DistributedTensor& tensor, + Func&& func) + { + constexpr auto y_lengths = tensor.get_tile_distribution() + .get_y_vector_lengths(); + + // Sweep with stride in first dimension + static_for<0, y_lengths[0], Stride>{}([&](auto i) { + // Create indices for this strided position + auto indices = make_multi_index(i); + + // Complete remaining dimensions normally + sweep_remaining_dims<1>(tensor, func, indices); + }); + } + +Block Sweep for Cache Optimization +---------------------------------- + +This pattern leverages shared memory to avoid :ref:`ck_tile_lds_bank_conflicts`: + +.. 
code-block:: cpp + + template + __device__ void block_sweep_pattern( + StaticDistributedTensor& tensor) + { + constexpr auto y_lengths = tensor.get_tile_distribution() + .get_y_vector_lengths(); + constexpr index_t num_blocks = (y_lengths[0] + BlockSize - 1) / BlockSize; + + // Process in blocks for better cache utilization + static_for<0, num_blocks, 1>{}([&](auto block_id) { + constexpr index_t block_start = block_id * BlockSize; + constexpr index_t block_end = min(block_start + BlockSize, y_lengths[0]); + + // Load block data into shared memory + __shared__ DataType block_cache[BlockSize][y_lengths[1]]; + + // Cooperative load + static_for{}([&](auto i) { + static_for<0, y_lengths[1], 1>{}([&](auto j) { + auto indices = make_multi_index(i, j); + block_cache[i - block_start][j] = tensor.get_element(indices); + }); + }); + + __syncthreads(); + + // Process from cache + static_for<0, block_end - block_start, 1>{}([&](auto i) { + static_for<0, y_lengths[1], 1>{}([&](auto j) { + DataType value = block_cache[i][j]; + value = complex_process(value); + + auto indices = make_multi_index(block_start + i, j); + tensor.set_element(indices, value); + }); + }); + }); + } + +Performance Characteristics +=========================== + +Sweep operations provide several performance benefits: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Sweep Performance Benefits" + B1["Zero runtime overhead
Compile-time unrolling"] + B2["Perfect memory coalescing
Sequential access patterns"] + B3["Automatic vectorization
Compiler optimizations"] + B4["Register reuse
X data stays in VGPR"] + end + + subgraph "Use Cases" + U1["Matrix Multiplication
Reuse A columns"] + U2["Convolution
Reuse filter weights"] + U3["Reduction
Accumulate over Y"] + U4["Broadcast
Apply X to all Y"] + end + + B1 --> Performance["High Performance"] + B2 --> Performance + B3 --> Performance + B4 --> Performance + + Performance --> U1 + Performance --> U2 + Performance --> U3 + Performance --> U4 + + style Performance fill:#d1fae5,stroke:#10b981,stroke-width:3px + + + + + +.. image:: diagrams/sweep_tile_3.svg + :alt: Diagram + :align: center + +Compiler Optimizations +---------------------- + +Using :ref:`ck_tile_load_store_traits` and :ref:`ck_tile_space_filling_curve` enables optimal memory access patterns: + +.. code-block:: cpp + + // The compiler can optimize sweep patterns effectively + template + __device__ void optimized_sweep_example( + StaticDistributedTensor& tensor) + { + // This sweep pattern: + sweep_tile(tensor, [&](auto indices) { + tensor.set_element(indices, tensor.get_element(indices) * 2.0f); + }); + + // Compiles to something like: + // #pragma unroll + // for (index_t i = 0; i < tensor.size(); ++i) { + // tensor[i] *= 2.0f; + // } + + // With: + // - Complete unrolling for small tensors + // - Vectorized loads/stores + // - No function call overhead + // - Perfect instruction scheduling + } + +Integration with CK Tile Components +=================================== + +Complete workflow example: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + flowchart TB + subgraph "Complete Workflow" + TD["TileDistribution
Define data layout"] + TW["TileWindow
Create view"] + DT["DistributedTensor
Load X data"] + ST["SweepTile
Iterate Y positions"] + R["Results
Store outputs"] + end + + TD --> TW + TW --> DT + DT --> ST + ST --> R + + style TD fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style ST fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style R fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + + + +.. image:: diagrams/sweep_tile_4.svg + :alt: Diagram + :align: center + +.. code-block:: cpp + + template + __global__ void complete_tile_kernel( + const DataType* input, + DataType* output, + index_t M, index_t N) + { + // 1. Define distribution + constexpr index_t BlockM = 64; + constexpr index_t BlockN = 64; + + using Distribution = TileDistribution< + Sequence, + Sequence<16, 16> + >; + + // 2. Create tile windows + auto input_window = make_tile_window( + input, make_tuple(M, N), + make_tuple(blockIdx.y * BlockM, blockIdx.x * BlockN), + Distribution{} + ); + + auto output_window = make_tile_window( + output, make_tuple(M, N), + make_tuple(blockIdx.y * BlockM, blockIdx.x * BlockN), + Distribution{} + ); + + // 3. Load input tile + auto input_tile = make_static_distributed_tensor(); + input_window.load(input_tile); + + // 4. Create output tile + auto output_tile = make_static_distributed_tensor(); + + // 5. Process with sweep + sweep_tile(input_tile, [&](auto indices) { + DataType value = input_tile.get_element(indices); + DataType result = complex_computation(value); + output_tile.set_element(indices, result); + }); + + // 6. Store results + output_window.store(output_tile); + } + + +Summary +======= + +SweepTile provides clean and efficient iteration over distributed data: + +- **Efficiency**: Load once, use many times pattern +- **Simplicity**: Clean lambda-based iteration abstraction +- **Performance**: Zero overhead with perfect access patterns +- **Flexibility**: Various sweep patterns for different algorithms + +Key benefits: + +1. **Memory Bandwidth**: Optimal reuse of loaded data +2. **Register Pressure**: Keep hot data in fastest memory +3. **Code Clarity**: Express algorithms naturally +4. 
**Compiler Optimization**: Enable aggressive optimizations + +The sweep pattern is fundamental to high-performance GPU kernels, turning complex iteration patterns into simple, efficient operations. Combined with TileDistribution and TileWindow, sweep operations complete the toolkit for clean and performant GPU computing. diff --git a/docs/conceptual/ck_tile/swizzling_example.rst b/docs/conceptual/ck_tile/swizzling_example.rst new file mode 100644 index 0000000000..f74038c954 --- /dev/null +++ b/docs/conceptual/ck_tile/swizzling_example.rst @@ -0,0 +1,495 @@ +.. meta:: + :description: CK Tile memory swizzling with Morton ordering example + :keywords: CK Tile, swizzling, Morton ordering, Z-order curve, GPU optimization + +.. _ck_tile_swizzling_example: + +************************************** +Memory Swizzling with Morton Ordering +************************************** + +Overview +======== + +This chapter demonstrates a practical application of tensor descriptors for implementing memory swizzling patterns, specifically Morton ordering (Z-order curve) within tiles. Memory swizzling is used to optimize GPU memory access patterns and reduce :ref:`bank conflicts <ck_tile_lds_bank_conflicts>`. Morton ordering provides a space-filling curve that maintains spatial locality while enabling efficient parallel access. See :ref:`ck_tile_space_filling_curve` for more information about space-filling curves. + +Morton ordering is widely used in: + +- **GPU Texture Memory**: Optimizing cache efficiency for 2D texture access +- **Matrix Operations**: Reducing memory bank conflicts in shared memory +- **Image Processing**: Improving locality for block-based algorithms +- **Scientific Computing**: Enhancing data access patterns for stencil operations + +Understanding Morton Ordering +============================= + +Morton ordering interleaves the bits of 2D coordinates to create a 1D ordering that preserves spatial locality.
For a 2D coordinate (y, x), we split each coordinate into its binary bits and interleave them: + +- y = y₁y₀ (2 bits) +- x = x₁x₀ (2 bits) +- Morton index = y₁x₁y₀x₀ (4 bits) + +This creates a Z-shaped traversal pattern within each tile: + +.. code-block:: cpp + + // Morton encoding for 2D coordinates + template + __host__ __device__ index_t morton_encode_2d(index_t y, index_t x) { + index_t result = 0; + for (index_t i = 0; i < NumBits; ++i) { + index_t bit_y = (y >> i) & 1; + index_t bit_x = (x >> i) & 1; + result |= (bit_y << (2*i + 1)) | (bit_x << (2*i)); + } + return result; + } + + // Morton decoding back to 2D coordinates + template + __host__ __device__ void morton_decode_2d( + index_t morton_idx, + index_t& y, + index_t& x) + { + y = 0; + x = 0; + for (index_t i = 0; i < NumBits; ++i) { + y |= ((morton_idx >> (2*i + 1)) & 1) << i; + x |= ((morton_idx >> (2*i)) & 1) << i; + } + } + +Morton Pattern Analysis +----------------------- + +The Morton index layout in a 4×4 tile follows this pattern: + +.. code-block:: text + + Morton Index Layout: + 0 1 4 5 + 2 3 6 7 + 8 9 12 13 + 10 11 14 15 + +Bit pattern breakdown: + +.. code-block:: text + + (0,0) = (00, 00) → 0 = 0000 + (0,1) = (00, 01) → 1 = 0001 + (0,2) = (00, 10) → 4 = 0100 + (0,3) = (00, 11) → 5 = 0101 + (1,0) = (01, 00) → 2 = 0010 + (1,1) = (01, 01) → 3 = 0011 + (1,2) = (01, 10) → 6 = 0110 + (1,3) = (01, 11) → 7 = 0111 + +Stage 1: Tiling with UnmergeTransform +====================================== + +First, we split our texture into tiles using tensor descriptors (see :ref:`ck_tile_descriptors` and :ref:`ck_tile_transforms`). This creates a hierarchical structure: (Y_blk, y_in, X_blk, x_in). + +.. 
code-block:: cpp + + template + struct TiledTextureDescriptor { + static constexpr index_t NumTilesY = H / TileSize; + static constexpr index_t NumTilesX = W / TileSize; + + // Original descriptor for H×W texture + using BaseDesc = TensorDescriptor< + Sequence, + Sequence // Row-major layout + >; + + // Stage 1: Split into tiles + // Transform: [H, W] → [NumTilesY, TileSize, NumTilesX, TileSize] + using TiledDesc = decltype( + transform_tensor_descriptor( + BaseDesc{}, + make_tuple( + make_unmerge_transform(Sequence{}), + make_unmerge_transform(Sequence{}) + ), + Sequence<0>{}, // Y dimension + Sequence<1>{}, // X dimension + Sequence<0, 1>{}, // Y → (Y_blk, y_in) + Sequence<2, 3>{} // X → (X_blk, x_in) + ) + ); + }; + +Example usage for an 8×8 texture with 4×4 tiles: + +.. code-block:: cpp + + // Create tiled descriptor + using TiledDesc8x8 = TiledTextureDescriptor<8, 8, 4>::TiledDesc; + + // Access pattern: iterate tile by tile + template + __device__ void process_tiled_texture(const DataType* texture) { + TiledDesc8x8 desc; + + // Process each tile + for (index_t y_blk = 0; y_blk < 2; ++y_blk) { + for (index_t x_blk = 0; x_blk < 2; ++x_blk) { + // Process elements within tile + for (index_t y_in = 0; y_in < 4; ++y_in) { + for (index_t x_in = 0; x_in < 4; ++x_in) { + // Calculate offset using descriptor + index_t offset = desc.calculate_offset({ + y_blk, y_in, x_blk, x_in + }); + + DataType value = texture[offset]; + // Process value... + } + } + } + } + } + +Stage 2: Morton Ordering with MergeTransform +============================================ + +The key insight is that MergeTransform enables Morton ordering by reordering and merging coordinate bits. The transformation involves: + +1. Split coordinates into individual bits using UnmergeTransform +2. Reorder and merge bits using MergeTransform to create the Morton index + +This leverages the coordinate transformation system described in :ref:`ck_tile_coordinate_systems`. 
+ +Mathematical Foundation +----------------------- + +.. code-block:: cpp + + template + struct MortonTransform { + static_assert(TileSize == 4, "This example assumes 4x4 tiles"); + + // Split 4 → (2, 2) for bit extraction + using SplitTransform = UnmergeTransform>; + + // Merge bits in Morton order: (y₁, x₁, y₀, x₀) → Morton + using MortonMergeTransform = MergeTransform>; + + // The merge operation computes: + // morton_idx = y₁×8 + x₁×4 + y₀×2 + x₀ + // This matches the bit interleaving pattern! + }; + +Complete Morton Implementation +------------------------------ + +Here's a complete implementation combining both stages: + +.. code-block:: cpp + + template + struct MortonSwizzledTexture { + static constexpr index_t NumTilesY = H / TileSize; + static constexpr index_t NumTilesX = W / TileSize; + + // Manual Morton implementation for reliability + __device__ static void apply_morton_swizzling( + const DataType* input, + DataType* output) + { + // Process each tile + for (index_t tile_y = 0; tile_y < NumTilesY; ++tile_y) { + for (index_t tile_x = 0; tile_x < NumTilesX; ++tile_x) { + // Apply Morton ordering within tile + for (index_t morton_idx = 0; morton_idx < TileSize * TileSize; ++morton_idx) { + // Decode Morton index to tile coordinates + index_t y_in, x_in; + morton_decode_2d<2>(morton_idx, y_in, x_in); + + // Calculate global coordinates + index_t global_y = tile_y * TileSize + y_in; + index_t global_x = tile_x * TileSize + x_in; + + // Calculate linear indices + index_t src_idx = global_y * W + global_x; + index_t dst_idx = (tile_y * NumTilesX + tile_x) * TileSize * TileSize + morton_idx; + + output[dst_idx] = input[src_idx]; + } + } + } + } + }; + +Memory Access Pattern Analysis +============================== + +An analysis of the benefits of Morton ordering for different access patterns: + +..
code-block:: cpp + + template + struct AccessPatternAnalyzer { + // Analyze spatial locality + __host__ static void analyze_morton_locality() { + printf("Morton Order Spatial Locality Analysis:\n"); + printf("Adjacent indices and their 2D distance:\n"); + + for (index_t i = 0; i < TileSize * TileSize - 1; ++i) { + index_t y1, x1, y2, x2; + morton_decode_2d<2>(i, y1, x1); + morton_decode_2d<2>(i + 1, y2, x2); + + index_t manhattan_dist = abs(y2 - y1) + abs(x2 - x1); + printf("Morton %2d→%2d: (%d,%d)→(%d,%d), distance: %d\n", + i, i+1, y1, x1, y2, x2, manhattan_dist); + } + } + + // Compare cache line usage + __host__ static void analyze_cache_efficiency() { + constexpr index_t CacheLineSize = 128; // bytes + constexpr index_t ElementSize = sizeof(float); + constexpr index_t ElementsPerCacheLine = CacheLineSize / ElementSize; + + printf("\nCache Efficiency Analysis:\n"); + printf("Cache line size: %d bytes (%d floats)\n", + CacheLineSize, ElementsPerCacheLine); + + // Row-major access + index_t row_major_lines = 0; + for (index_t y = 0; y < TileSize; ++y) { + for (index_t x = 0; x < TileSize; x += ElementsPerCacheLine) { + row_major_lines++; + } + } + + // Morton access + index_t morton_lines = 0; + index_t current_line = -1; + for (index_t i = 0; i < TileSize * TileSize; ++i) { + index_t y, x; + morton_decode_2d<2>(i, y, x); + index_t linear_idx = y * TileSize + x; + index_t cache_line = linear_idx / ElementsPerCacheLine; + + if (cache_line != current_line) { + morton_lines++; + current_line = cache_line; + } + } + + printf("Row-major: %d cache lines\n", row_major_lines); + printf("Morton: %d cache lines\n", morton_lines); + } + }; + +GPU Kernel Implementation +========================= + +A complete GPU kernel using Morton ordering for optimized memory access: + +.. 
code-block:: cpp + + template + __global__ void morton_optimized_kernel( + const DataType* __restrict__ input, + DataType* __restrict__ output, + index_t H, index_t W) + { + // Shared memory with Morton layout + __shared__ DataType smem[BlockSize * BlockSize]; + + // Thread and block indices + const index_t tid_x = threadIdx.x; + const index_t tid_y = threadIdx.y; + const index_t bid_x = blockIdx.x; + const index_t bid_y = blockIdx.y; + + // Global position + const index_t global_x = bid_x * BlockSize + tid_x; + const index_t global_y = bid_y * BlockSize + tid_y; + + // Load to shared memory with coalescing + if (global_x < W && global_y < H) { + smem[tid_y * BlockSize + tid_x] = input[global_y * W + global_x]; + } + __syncthreads(); + + // Process tiles with Morton ordering + constexpr index_t TilesPerBlock = BlockSize / TileSize; + + // Each thread processes one element in Morton order + const index_t tile_id = (tid_y / TileSize) * TilesPerBlock + (tid_x / TileSize); + const index_t morton_in_tile = (tid_y % TileSize) * TileSize + (tid_x % TileSize); + + // Decode Morton index + index_t y_in_tile, x_in_tile; + morton_decode_2d<2>(morton_in_tile, y_in_tile, x_in_tile); + + // Calculate position in shared memory + const index_t tile_y = tile_id / TilesPerBlock; + const index_t tile_x = tile_id % TilesPerBlock; + const index_t smem_y = tile_y * TileSize + y_in_tile; + const index_t smem_x = tile_x * TileSize + x_in_tile; + + // Process with Morton access pattern + DataType value = smem[smem_y * BlockSize + smem_x]; + + // Apply computation... + value = compute_function(value); + + // Store result + if (global_x < W && global_y < H) { + output[global_y * W + global_x] = value; + } + } + +Bank Conflict Reduction +======================= + +Morton ordering is particularly effective for reducing shared memory bank conflicts (complementing the XOR preshuffle technique described in :ref:`ck_tile_lds_index_swapping`): + +.. 
code-block:: cpp + + template + struct BankConflictAnalysis { + static constexpr index_t NumBanks = 32; + static constexpr index_t BankWidth = 4; // bytes + + template + __host__ static void analyze_bank_conflicts( + const char* pattern_name, + AccessPattern access_func) + { + index_t bank_access[NumBanks] = {0}; + + // Simulate warp access + for (index_t tid = 0; tid < WarpSize; ++tid) { + index_t offset = access_func(tid); + index_t bank = (offset * sizeof(float) / BankWidth) % NumBanks; + bank_access[bank]++; + } + + // Find maximum conflict + index_t max_conflict = 0; + for (index_t bank = 0; bank < NumBanks; ++bank) { + max_conflict = max(max_conflict, bank_access[bank]); + } + + printf("%s: %d-way bank conflict\n", pattern_name, max_conflict); + } + + __host__ static void compare_access_patterns() { + printf("Bank Conflict Analysis for 4x4 Tile Access:\n"); + + // Row-major access + analyze_bank_conflicts("Row-major", [](index_t tid) { + return (tid / 4) * 4 + (tid % 4); + }); + + // Morton access + analyze_bank_conflicts("Morton", [](index_t tid) { + index_t y, x; + morton_decode_2d<2>(tid % 16, y, x); + return y * 4 + x; + }); + } + }; + +Practical Applications +====================== + +Real-world usage of Morton ordering in CK Tile: + +**1. Texture Cache Optimization** + +.. 
code-block:: cpp + + template + struct TextureCacheOptimized { + static constexpr index_t TextureTileSize = 8; + + __device__ static DataType sample_2d_morton( + const DataType* texture, + float u, float v, + index_t width, index_t height) + { + // Convert normalized coordinates to texel coordinates + index_t x = u * width; + index_t y = v * height; + + // Determine tile + index_t tile_x = x / TextureTileSize; + index_t tile_y = y / TextureTileSize; + + // Position within tile + index_t x_in_tile = x % TextureTileSize; + index_t y_in_tile = y % TextureTileSize; + + // Convert to Morton index + index_t morton_idx = morton_encode_2d<3>(y_in_tile, x_in_tile); + + // Calculate final offset + index_t tile_offset = (tile_y * (width / TextureTileSize) + tile_x) + * TextureTileSize * TextureTileSize; + + return texture[tile_offset + morton_idx]; + } + }; + +**2. Matrix Multiplication with Swizzled Tiles** + +For complete GEMM optimization techniques, see :ref:`ck_tile_gemm_optimization`. + +.. code-block:: cpp + + template + struct SwizzledGEMM { + __device__ static void load_tile_morton( + const DataType* matrix, + DataType* tile, + index_t row_offset, + index_t col_offset, + index_t ld) + { + // Load tile with Morton ordering for better LDS bank utilization + #pragma unroll + for (index_t i = 0; i < TileM * TileN; ++i) { + index_t row_in_tile, col_in_tile; + morton_decode_2d<3>(i, row_in_tile, col_in_tile); + + if (row_in_tile < TileM && col_in_tile < TileN) { + index_t global_row = row_offset + row_in_tile; + index_t global_col = col_offset + col_in_tile; + tile[i] = matrix[global_row * ld + global_col]; + } + } + } + }; + +Summary +======= + +Morton ordering with CK Tile provides memory optimization capabilities: + +- **Spatial Locality**: Z-order curve maintains 2D locality in 1D memory layout +- **Bank Conflict Reduction**: Distributed access patterns across memory banks +- **Cache Efficiency**: Better utilization of cache lines for 2D access patterns +- 
**Mathematical Framework**: Tensor descriptors express swizzling cleanly +- **Practical Implementation**: Bit manipulation provides reliable results + +Key implementation insights: + +1. **MergeTransform** is essential for expressing Morton bit interleaving +2. **Manual bit manipulation** provides reliable and efficient implementation +3. **Tiling + Morton** combines hierarchical locality with local optimization +4. **GPU-specific tuning** adapts patterns to hardware characteristics + +The tensor descriptor approach provides the mathematical framework for expressing these complex memory patterns, while practical implementations often use direct bit manipulation for efficiency and reliability. + +For more examples of practical CK Tile usage, see :ref:`ck_tile_convolution_example`. For the underlying buffer and tensor abstractions, see :ref:`ck_tile_buffer_views` and :ref:`ck_tile_tensor_views`. diff --git a/docs/conceptual/ck_tile/tensor_coordinates.rst b/docs/conceptual/ck_tile/tensor_coordinates.rst new file mode 100644 index 0000000000..4e9240b83c --- /dev/null +++ b/docs/conceptual/ck_tile/tensor_coordinates.rst @@ -0,0 +1,459 @@ +.. meta:: + :description: CK Tile tensor coordinates and MultiIndex documentation + :keywords: CK Tile, MultiIndex, tensor coordinates, GPU programming + +.. _ck_tile_tensor_coordinates: + +******************* +Tensor Coordinates +******************* + +Overview +======== + +Before diving into transforms and adaptors (see :ref:`ck_tile_transforms` and :ref:`ck_tile_adaptors`), it's essential to understand the basic coordinate system in CK Tile. MultiIndex is a container that extends the C++ array with additional operations for multi-dimensional indexing. It is the fundamental building block used throughout the system. + +MultiIndex serves as the common currency between different coordinate spaces (see :ref:`ck_tile_coordinate_systems`), enabling seamless transformation and navigation through complex tensor layouts. 
Every transform, adaptor, and descriptor in CK Tile operates on these coordinate containers. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "MultiIndex Structure" + MI["MultiIndex
Container for N integers"] + D0["Dimension 0"] + D1["Dimension 1"] + D2["Dimension 2"] + DN["Dimension N-1"] + end + + subgraph "Usage Context" + T["Transforms
"] + A["Adaptors
"] + TV["Tensors
"] + end + + MI --> D0 + MI --> D1 + MI --> D2 + MI --> DN + + T --> MI + A --> MI + TV --> MI + + style MI fill:#f3e5f5,stroke:#7b1fa2,stroke-width:3px + style D0 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style D1 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style D2 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style DN fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style T fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style A fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style TV fill:#ffebee,stroke:#d32f2f,stroke-width:2px + + + +.. image:: diagrams/tensor_coordinates_1.svg + :alt: Diagram + :align: center + +MultiIndex Implementation +========================= + +The C++ implementation provides both compile-time and runtime flexibility: + +.. code-block:: cpp + + // Basic MultiIndex structure + template + struct MultiIndex { + static constexpr index_t kNDim = NDim; + + // Storage for coordinate values + array data_; + + // Constructors + __host__ __device__ constexpr MultiIndex() : data_{} {} + + __host__ __device__ constexpr MultiIndex( + const array& values) : data_(values) {} + + // Element access + __host__ __device__ constexpr index_t& operator[](index_t i) { + return data_[i]; + } + + __host__ __device__ constexpr const index_t& operator[](index_t i) const { + return data_[i]; + } + + // Size query + __host__ __device__ static constexpr index_t size() { + return NDim; + } + }; + +Creating and Using MultiIndex +============================= + +CK Tile provides convenient factory functions for creating MultiIndex objects: + +.. 
code-block:: cpp + + #include + + __device__ void example_multiindex_usage() { + // Create 3D coordinate with runtime values + auto coord = make_multi_index(1, 2, 3); + + // Access dimensions + auto x = coord[0]; // Returns 1 + auto y = coord[1]; // Returns 2 + auto z = coord[2]; // Returns 3 + + // For compile-time coordinates, use number<> + auto coord_static = make_multi_index( + number<1>{}, number<2>{}, number<3>{} + ); + + // Create from tuple + auto shape = make_tuple(128, 256, 64); + auto coord2 = to_multi_index(shape); + + // Modify coordinate + auto new_coord = coord; + new_coord[0] = 5; // Set X to 5 + + // Use in tensor access + auto tensor = make_naive_tensor_view( + data_ptr, shape, strides + ); + + // Create tensor coordinate for access + auto tensor_coord = make_tensor_coordinate( + tensor.get_tensor_descriptor(), coord + ); + } + +For more advanced coordinate operations and movement patterns, see :ref:`ck_tile_coordinate_movement`. + +Compile-Time Optimization +------------------------- + +CK Tile leverages C++ templates for zero-overhead abstractions: + +.. code-block:: cpp + + // Compile-time MultiIndex operations + template + __host__ __device__ constexpr auto make_static_multi_index() { + return MultiIndex{array{Is...}}; + } + + // Example: Matrix access pattern + template + __device__ void optimized_matrix_access(float* matrix) { + // Compile-time coordinates + constexpr auto origin = make_static_multi_index<0, 0>(); + constexpr auto corner = make_static_multi_index(); + + // Loop unrolling with compile-time indices + #pragma unroll + for (index_t i = 0; i < M; ++i) { + #pragma unroll + for (index_t j = 0; j < N; ++j) { + auto coord = make_multi_index(i, j); + // Compiler can optimize based on known bounds + process_element(matrix[i * N + j]); + } + } + } + +MultiIndex in Coordinate Flow +============================= + +MultiIndex serves as the interface between user code and the transformation pipeline: + +.. 
+ Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + flowchart TB + subgraph CF ["Coordinate Flow"] + direction LR + UI["User Input
[1, 2, 3]"] --> MI["MultiIndex
Storage"] + MI --> TR["Transform
Processing"] + TR --> MO["MultiIndex
Output"] + MO --> TA["Tensor Access
element(coord)"] + end + + subgraph EX ["Example: 3D Tensor Access"] + direction LR + T3D["3D Tensor
shape=[4,5,6]"] --> COORD["MultiIndex(3, [1,2,3])"] + COORD --> ELEM["Element at
position [1,2,3]"] + end + + style UI fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + style MI fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + style MO fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + style COORD fill:#fff3e0,stroke:#f57c00,stroke-width:2px + + + +.. image:: diagrams/tensor_coordinates_2.svg + :alt: Diagram + :align: center + +Common Usage Patterns +===================== + +Pattern 1: Tensor Iteration +--------------------------- + +.. code-block:: cpp + + template + __device__ void iterate_2d_tensor(DataType* tensor) { + // Iterate through tensor using MultiIndex + for (index_t i = 0; i < M; ++i) { + for (index_t j = 0; j < N; ++j) { + auto coord = make_multi_index(i, j); + + // Use coordinate for structured access + DataType& element = tensor[coord[0] * N + coord[1]]; + + // Process element + element = process_value(element); + } + } + } + +Pattern 2: Boundary Checking +---------------------------- + +.. code-block:: cpp + + template + __device__ bool is_valid_coordinate( + const MultiIndex& coord, + const MultiIndex& shape) + { + for (index_t i = 0; i < NDim; ++i) { + if (coord[i] < 0 || coord[i] >= shape[i]) { + return false; + } + } + return true; + } + + // Usage in kernel + __global__ void safe_tensor_kernel(float* tensor, index_t H, index_t W) { + auto coord = make_multi_index( + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.x * blockDim.x + threadIdx.x + ); + + auto shape = make_multi_index(H, W); + + if (is_valid_coordinate(coord, shape)) { + tensor[coord[0] * W + coord[1]] = compute_value(coord); + } + } + +Pattern 3: Transform Chaining +----------------------------- + +.. 
code-block:: cpp + + // Apply multiple transformations to coordinates + template + __device__ auto apply_transform_chain( + const MultiIndex<2>& input_coord, + const Transform1& t1, + const Transform2& t2) + { + // First transformation + auto intermediate = t1.calculate_bottom_index(input_coord); + + // Second transformation + auto final = t2.calculate_bottom_index(intermediate); + + return final; + } + +Advanced MultiIndex Operations +============================== + +Arithmetic Operations +--------------------- + +.. code-block:: cpp + + template + struct MultiIndexOps { + // Element-wise addition + __device__ static MultiIndex add( + const MultiIndex& a, + const MultiIndex& b) + { + MultiIndex result; + #pragma unroll + for (index_t i = 0; i < NDim; ++i) { + result[i] = a[i] + b[i]; + } + return result; + } + + // Scalar multiplication + __device__ static MultiIndex scale( + const MultiIndex& coord, + index_t factor) + { + MultiIndex result; + #pragma unroll + for (index_t i = 0; i < NDim; ++i) { + result[i] = coord[i] * factor; + } + return result; + } + + // Dot product (for linear indexing) + __device__ static index_t dot( + const MultiIndex& coord, + const MultiIndex& strides) + { + index_t result = 0; + #pragma unroll + for (index_t i = 0; i < NDim; ++i) { + result += coord[i] * strides[i]; + } + return result; + } + }; + +Specialized Coordinates +----------------------- + +.. 
code-block:: cpp + + // Thread coordinate helper + struct ThreadCoordinate { + __device__ static auto get_thread_coord_1d() { + return make_multi_index( + blockIdx.x * blockDim.x + threadIdx.x + ); + } + + __device__ static auto get_thread_coord_2d() { + return make_multi_index( + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.x * blockDim.x + threadIdx.x + ); + } + + __device__ static auto get_thread_coord_3d() { + return make_multi_index( + blockIdx.z * blockDim.z + threadIdx.z, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.x * blockDim.x + threadIdx.x + ); + } + }; + +Integration with Tensor Operations +================================== + +MultiIndex is the foundation for all tensor operations in CK Tile (see :ref:`ck_tile_tensor_views` and :ref:`ck_tile_buffer_views` for tensor abstractions): + +.. code-block:: cpp + + template + __device__ void tensor_operation_example(TensorView& tensor) { + // Get tensor shape as MultiIndex + auto shape = tensor.get_tensor_descriptor().get_lengths(); + + // Create coordinate for center element + MultiIndex center; + #pragma unroll + for (index_t i = 0; i < TensorView::kNDim; ++i) { + center[i] = shape[i] / 2; + } + + // Access center element + auto center_value = tensor(center); + + // Create stencil pattern using MultiIndex + constexpr auto offsets = make_tuple( + make_multi_index(-1, 0), // North + make_multi_index( 1, 0), // South + make_multi_index( 0, -1), // West + make_multi_index( 0, 1) // East + ); + + // Apply stencil + auto sum = center_value; + static_for<0, 4, 1>{}([&](auto i) { + auto neighbor = MultiIndexOps<2>::add(center, get(offsets)); + if (is_valid_coordinate(neighbor, shape)) { + sum += tensor(neighbor); + } + }); + } + +Performance Considerations +========================== + +MultiIndex is designed for zero-overhead abstraction (see :ref:`ck_tile_gpu_basics` for GPU performance fundamentals): + +1. 
**Compile-Time Resolution**: When dimensions are known at compile time, all operations are inlined +2. **Register Allocation**: Small fixed-size arrays typically stay in registers +3. **Vectorization**: Compiler can vectorize operations on MultiIndex arrays +4. **Memory Layout**: Contiguous storage enables efficient cache usage + +.. code-block:: cpp + + // Performance-optimized coordinate operations + template + struct OptimizedCoordOps { + // Fused multiply-add for linear indexing + __device__ __forceinline__ static index_t + compute_offset(const MultiIndex& coord, + const MultiIndex& strides) + { + index_t offset = 0; + + // Unroll for small dimensions + if constexpr (NDim <= 4) { + #pragma unroll + for (index_t i = 0; i < NDim; ++i) { + offset = __fma_rn(coord[i], strides[i], offset); + } + } else { + // Partial unrolling for larger dimensions + #pragma unroll 4 + for (index_t i = 0; i < NDim; ++i) { + offset += coord[i] * strides[i]; + } + } + + return offset; + } + }; + +Summary +======= + +MultiIndex is the foundation of CK Tile's coordinate system: + +- **Simple Abstraction**: Container for N integers representing position +- **Universal Usage**: Every transform and adaptor operates on MultiIndex +- **Type-Safe**: Compile-time size and bounds checking in C++ +- **Zero-Overhead**: Template metaprogramming ensures no runtime cost +- **Flexible**: Supports both compile-time and runtime coordinates + +Understanding MultiIndex is crucial before moving to transforms and adaptors, as they all build upon this fundamental coordinate representation. MultiIndex is the common language that allows all CK Tile components to work together seamlessly. + +For the complete picture of how MultiIndex fits into the CK Tile coordinate system, see :ref:`ck_tile_coordinate_systems`. For practical usage in tile distribution, see :ref:`ck_tile_tile_distribution`. 
diff --git a/docs/conceptual/ck_tile/tensor_views.rst b/docs/conceptual/ck_tile/tensor_views.rst new file mode 100644 index 0000000000..0c46e1e593 --- /dev/null +++ b/docs/conceptual/ck_tile/tensor_views.rst @@ -0,0 +1,482 @@ +.. _ck_tile_tensor_views: + +Tensor Views - Multi-Dimensional Structure +========================================== + +Overview +-------- + +While :ref:`BufferView ` provides the foundation for raw memory access, TensorView adds multi-dimensional structure to flat memory regions. This abstraction bridges the gap between how developers conceptualize data and how that data is physically stored in linear memory. TensorView enables coordinate-based access patterns that match the natural structure of algorithms while maintaining the performance characteristics necessary for efficient GPU computation. + +TensorView presents different logical views of the same underlying memory without copying data. A single memory region can be viewed as a row-major matrix, a column-major matrix, or a transposed matrix, using different TensorView configurations. This zero-copy abstraction enables flexible transformations and access patterns while maintaining optimal memory bandwidth utilization. + +TensorView Architecture +----------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Memory Foundation" + Memory["Flat Memory Array
0 1 2 3 4 5 6 7 8 9 10 11"] + end + + subgraph "Access Layer" + BufferView["BufferView
Linear Memory Access"] + Descriptor["TensorDescriptor
Shape & Stride Info"] + end + + subgraph "Tensor Layer" + TensorView["TensorView
Multi-dimensional Access"] + end + + subgraph "Logical View" + Matrix["2D Matrix View
[3×4]
[[0,1,2,3]
[4,5,6,7]
[8,9,10,11]]"] + end + + Memory --> BufferView + Memory --> Descriptor + BufferView --> TensorView + Descriptor --> TensorView + TensorView --> Matrix + + style Memory fill:#d1fae5,stroke:#10b981,stroke-width:2px + style BufferView fill:#dbeafe,stroke:#3b82f6,stroke-width:2px + style Descriptor fill:#fed7aa,stroke:#f59e0b,stroke-width:2px + style TensorView fill:#fce7f3,stroke:#ec4899,stroke-width:2px + style Matrix fill:#e9d5ff,stroke:#9333ea,stroke-width:2px + + + + + + +.. image:: diagrams/tensor_views_1.svg + :alt: Diagram + :align: center + +The Foundation: BufferView and TensorDescriptor +------------------------------------------------ + +TensorView builds upon two fundamental components that work in concert to provide structured access to memory. The :ref:`BufferView ` component handles the low-level memory access, providing type-safe operations with address space awareness. The :ref:`TensorDescriptor ` component encodes the multi-dimensional structure, including shape information and stride patterns that determine how coordinates map to memory offsets. + +This separation of concerns enables optimizations. The BufferView can optimize for the specific memory space while the TensorDescriptor can encode complex access patterns without concern for the underlying memory type. Together, they provide a complete abstraction for multi-dimensional data access. + +C++ Implementation +------------------ + +**File**: ``include/ck_tile/core/tensor/tensor_view.hpp`` + +Creating TensorViews +~~~~~~~~~~~~~~~~~~~~ + +The creation of a TensorView involves combining a BufferView with a TensorDescriptor. This process can be done explicitly for maximum control or through convenience functions for common patterns: + +.. 
code-block:: cpp + + #include + #include + #include + + // The actual C++ template signature from tensor_view.hpp: + // template + // struct tensor_view + + __device__ void example_tensor_creation() + { + // Create a 3x4 matrix in global memory + float data[12] = {0,1,2,3,4,5,6,7,8,9,10,11}; + + // Method 1: Create buffer and descriptor separately + auto buffer = make_buffer_view(data, 12); + auto desc = make_tensor_descriptor( + make_tuple(3, 4), // shape: 3 rows, 4 columns + make_tuple(4, 1) // strides: row stride=4, col stride=1 + ); + + // Create tensor view + auto tensor = make_tensor_view(buffer, desc); + + // Method 2: Use convenience function for packed layout + auto tensor2 = make_naive_tensor_view_packed( + data, // pointer + make_tuple(3, 4) // shape (strides calculated automatically) + ); + + // Access element at (1, 2) + float value = tensor(make_tuple(1, 2)); // Returns 6 + + // Update element + tensor(make_tuple(2, 1)) = 99.0f; + } + +Coordinate-Based Access +~~~~~~~~~~~~~~~~~~~~~~~ + +The fundamental operation of TensorView is translating multi-dimensional coordinates into memory accesses. This translation happens through an advanced pipeline that maintains efficiency while providing flexibility: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + flowchart LR + subgraph "User Input" + Coord["Coordinate
(1, 2)"] + end + + subgraph "TensorView Processing" + Shape["Shape Check
row < 3?
col < 4?"] + Stride["Apply Strides
offset = 1×4 + 2×1"] + Buffer["BufferView Access
buffer[6]"] + end + + subgraph "Result" + Value["Value: 6"] + end + + Coord --> Shape + Shape -->|Valid| Stride + Stride --> Buffer + Buffer --> Value + + style Coord fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + style Shape fill:#fef3c7,stroke:#f59e0b,stroke-width:2px + style Stride fill:#dcfce7,stroke:#10b981,stroke-width:2px + style Buffer fill:#dbeafe,stroke:#3b82f6,stroke-width:2px + style Value fill:#d1fae5,stroke:#10b981,stroke-width:2px + + + + + + +.. image:: diagrams/tensor_views_2.svg + :alt: Diagram + :align: center + +Memory Layouts and Strides +-------------------------- + +A key feature of TensorView is its ability to represent different memory layouts through stride manipulation. This capability enables zero-copy transformations that would otherwise require expensive memory operations: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Row-Major Layout (C-style)" + RM["Memory: [0,1,2,3,4,5,6,7,8,9,10,11]
Shape: (3,4)
Strides: (4,1)"] + RMMatrix["[[0, 1, 2, 3]
[4, 5, 6, 7]
[8, 9, 10, 11]]"] + RM --> RMMatrix + end + + subgraph "Column-Major Layout (Fortran-style)" + CM["Memory: [0,4,8,1,5,9,2,6,10,3,7,11]
Shape: (3,4)
Strides: (1,3)"] + CMMatrix["[[0, 1, 2, 3]
[4, 5, 6, 7]
[8, 9, 10, 11]]"] + CM --> CMMatrix + end + + subgraph "Custom Stride (Transposed View)" + TV["Memory: [0,1,2,3,4,5,6,7,8,9,10,11]
Shape: (4,3)
Strides: (1,4)"] + TVMatrix["[[0, 4, 8]
[1, 5, 9]
[2, 6, 10]
[3, 7, 11]]"] + TV --> TVMatrix + end + + style RM fill:#e0f2fe,stroke:#0284c7,stroke-width:2px + style CM fill:#fef3c7,stroke:#f59e0b,stroke-width:2px + style TV fill:#f3e8ff,stroke:#9333ea,stroke-width:2px + + + + + + +.. image:: diagrams/tensor_views_3.svg + :alt: Diagram + :align: center + +Row-Major vs Column-Major Layouts +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The choice of memory layout has profound implications for performance. Row-major layout, where consecutive elements in a row are stored contiguously, optimizes for row-wise traversal. Column-major layout optimizes for column-wise traversal. CK's TensorView abstraction allows algorithms to work with their natural access patterns regardless of the underlying storage: + +.. code-block:: cpp + + __device__ void example_memory_layouts() + { + float data[12] = {0,1,2,3,4,5,6,7,8,9,10,11}; + + // Row-major layout (default) + auto row_major = make_naive_tensor_view_packed( + data, make_tuple(3, 4) + ); + // Strides: (4, 1) - moving one row advances by 4 elements + + // Column-major layout through custom strides + auto col_major = make_tensor_view( + make_buffer_view(data, 12), + make_tensor_descriptor( + make_tuple(3, 4), // shape + make_tuple(1, 3) // strides: row stride=1, col stride=3 + ) + ); + + // Transposed view (no data copy!) + auto transposed = make_tensor_view( + make_buffer_view(data, 12), + make_tensor_descriptor( + make_tuple(4, 3), // transposed shape + make_tuple(1, 4) // transposed strides + ) + ); + + // All three views access the same memory, just differently + // row_major(1,2) == col_major(0,2) == transposed(2,1) + } + +Advanced Operations +------------------- + +Slicing and Subviews +~~~~~~~~~~~~~~~~~~~~ + +TensorView supports advanced slicing operations that create new views of subsets of the data. These operations are essential for algorithms that process data in blocks or tiles. See :ref:`ck_tile_tile_window` for production use. + +.. 
code-block:: cpp + + __device__ void example_slicing_operations() + { + // Create a larger tensor + float data[100]; + auto tensor = make_naive_tensor_view_packed( + data, make_tuple(10, 10) + ); + + // Create a subview using transforms + // This would typically be done with tile_window in production code + auto subview = make_tensor_view( + tensor.get_buffer_view(), + transform_tensor_descriptor( + tensor.get_tensor_descriptor(), + make_tuple( + make_pass_through_transform(number<5>{}), // 5 rows + make_pass_through_transform(number<5>{}) // 5 columns + ), + make_tuple(number<2>{}, number<3>{}) // offset (2,3) + ) + ); + + // subview now represents a 5x5 region starting at (2,3) + } + +Vectorized Access +~~~~~~~~~~~~~~~~~ + +GPUs achieve maximum memory bandwidth through vectorized operations. TensorView provides native support for vector loads and stores. See :ref:`ck_tile_load_store_traits` for more details. + +.. code-block:: cpp + + __device__ void example_vectorized_access() + { + float data[256]; + auto tensor = make_naive_tensor_view_packed( + data, make_tuple(16, 16) + ); + + // Create coordinate for vectorized access + auto coord = make_tensor_coordinate( + tensor.get_tensor_descriptor(), + make_tuple(4, 0) // row 4, starting at column 0 + ); + + // Load 4 consecutive elements as float4 + using float4 = vector_type::type; + auto vec4 = tensor.get_vectorized_elements(coord, 0); + + // Process vector data + vec4.x *= 2.0f; + vec4.y *= 2.0f; + vec4.z *= 2.0f; + vec4.w *= 2.0f; + + // Store back + tensor.set_vectorized_elements(coord, 0, vec4); + } + +Performance Considerations +-------------------------- + +Memory Access Patterns +~~~~~~~~~~~~~~~~~~~~~~ + +The efficiency of TensorView operations depends on memory access patterns. Understanding these patterns is important for achieving optimal performance. See :ref:`ck_tile_gpu_basics` for hardware considerations. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. 
+ Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Memory Access Patterns" + Seq["Sequential Access
(Good cache usage)"] + Stride["Strided Access
(May cause cache misses)"] + Random["Random Access
(Poor cache usage)"] + end + + subgraph "Optimization Strategies" + Opt1["Use row-major for row iteration"] + Opt2["Use col-major for column iteration"] + Opt3["Minimize stride between accesses"] + Opt4["Vectorize when possible"] + end + + Seq --> Opt1 + Stride --> Opt2 + Stride --> Opt3 + Random --> Opt4 + + style Seq fill:#d1fae5,stroke:#10b981,stroke-width:2px + style Stride fill:#fef3c7,stroke:#f59e0b,stroke-width:2px + style Random fill:#fee2e2,stroke:#ef4444,stroke-width:2px + + + + + + +.. image:: diagrams/tensor_views_4.svg + :alt: Diagram + :align: center + +Compile-Time Optimization +~~~~~~~~~~~~~~~~~~~~~~~~~ + +CK's TensorView leverages compile-time optimization to achieve zero-overhead abstraction. When tensor dimensions and strides are known at compile time, the entire coordinate-to-offset calculation can be resolved during compilation: + +.. code-block:: cpp + + // Compile-time known dimensions enable optimization + constexpr auto shape = make_tuple(number<256>{}, number<256>{}); + constexpr auto strides = make_tuple(number<256>{}, number<1>{}); + + auto tensor = make_tensor_view( + buffer, + make_tensor_descriptor(shape, strides) + ); + + // This access compiles to a single memory instruction + constexpr auto coord = make_tuple(number<5>{}, number<10>{}); + auto value = tensor(coord); // Offset calculated at compile time + +TensorView vs BufferView +------------------------ + +Understanding when to use TensorView versus BufferView is crucial for writing efficient code: + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. 
mermaid:: + + graph TB + subgraph "BufferView" + BV1["Linear indexing only"] + BV2["buffer[5]"] + BV3["No shape information"] + BV4["Direct memory access"] + end + + subgraph "TensorView" + TV1["Multi-dimensional indexing"] + TV2["tensor(1, 2)"] + TV3["Shape-aware operations"] + TV4["Coordinate transformations"] + end + + subgraph "Use Cases" + UC1["BufferView: Low-level memory ops"] + UC2["TensorView: Matrix/tensor algorithms"] + end + + BV1 --> UC1 + TV1 --> UC2 + + style BV1 fill:#dbeafe,stroke:#3b82f6,stroke-width:2px + style TV1 fill:#fce7f3,stroke:#ec4899,stroke-width:2px + + + + + + +.. image:: diagrams/tensor_views_5.svg + :alt: Diagram + :align: center + +BufferView excels at raw memory operations where linear access is natural or where the overhead of coordinate calculation would be prohibitive. TensorView is best suited for algorithms that operate in terms of multi-dimensional coordinates, such as matrix operations, image processing, or tensor contractions. + +Integration with Tile Distribution +---------------------------------- + +TensorView serves as the foundation for :ref:`tile distribution's ` higher-level abstractions. When combined with :ref:`tile windows ` and distribution patterns, TensorView enables the automatic generation of efficient access patterns: + +.. code-block:: cpp + + // TensorView provides the base abstraction + auto tensor_view = make_naive_tensor_view_packed( + global_memory, make_tuple(M, N) + ); + + // Tile window builds on TensorView for distributed access + auto tile_window = make_tile_window( + tensor_view, + tile_shape, + origin, + distribution + ); + + // The distribution automatically generates optimal access patterns + auto distributed_tensor = tile_window.load(); + +Summary +------- + +TensorView bridges the gap between logical multi-dimensional data structures and physical memory layout. 
Through its advanced design, TensorView provides: + +**Multi-dimensional Indexing**: Natural coordinate-based access to data, matching how algorithms conceptualize their operations. This abstraction eliminates error-prone manual index calculations while maintaining performance. + +**Flexible Memory Layouts**: Support for row-major, column-major, and custom stride patterns enables algorithms to work with data in its most natural form. Zero-copy transformations like transposition become stride manipulations. + +**Zero-Copy Views**: The ability to create different logical views of the same physical memory enables flexible transformations without the overhead of data movement. This capability is essential for efficient GPU programming where memory bandwidth is often the limiting factor. + +**Type Safety**: Dimensions and memory spaces are encoded in the type system, catching errors at compile time rather than runtime. This safety comes without performance overhead thanks to template metaprogramming. + +**Seamless Integration**: TensorView works harmoniously with :ref:`BufferView ` for low-level access and serves as the foundation for higher-level abstractions like :ref:`tile windows ` and :ref:`distributed tensors `. + +The abstraction enables writing dimension-agnostic algorithms while maintaining high performance through compile-time optimizations. + +Next Steps +---------- + +Continue to :ref:`ck_tile_coordinate_systems` to understand the mathematical foundation of coordinate transformations in CK Tile. diff --git a/docs/conceptual/ck_tile/terminology.rst b/docs/conceptual/ck_tile/terminology.rst new file mode 100644 index 0000000000..7d5fc87fe9 --- /dev/null +++ b/docs/conceptual/ck_tile/terminology.rst @@ -0,0 +1,383 @@ +.. 
_ck_tile_terminology: + +Terminology Reference - Key Concepts and Definitions +==================================================== + +Overview +-------- + +The Composable Kernel framework introduces concepts and abstractions that form the foundation of its approach to high-performance GPU computing. This terminology reference serves as a comprehensive guide to the language of CK, providing detailed explanations of each term along with practical examples of their usage in C++ code. + +The terminology of CK reflects its layered architecture, with concepts building upon one another in a logical progression. From the fundamental notion of tiles and distributions to the compile-time coordinate transformation systems, each term represents a carefully designed abstraction that serves a specific purpose in the overall framework. This reference is organized to mirror this conceptual hierarchy, starting with core concepts and progressing through increasingly specialized terminology. + +As you explore this reference, you'll notice that many terms are interconnected, reflecting the holistic nature of the CK design. A tile is not just a block of data but a fundamental unit of work distribution. A distribution is not merely a pattern but a mathematical framework for optimal resource utilization. These interconnections are intentional and understanding them is crucial for effective use of the framework. + +Core Concepts +------------- + +Tile +~~~~ +The concept of a tile represents the fundamental unit of data organization in the CK framework. A tile is a contiguous block of data that is processed as a cohesive unit by a coordinated group of threads. This abstraction serves multiple critical purposes in achieving high performance on GPU architectures. By organizing data into tiles, the framework ensures that memory accesses exhibit spatial locality, enabling efficient use of cache hierarchies. 
The tile size is chosen to balance several competing factors: it must be large enough to amortize the overhead of memory transactions, yet small enough to fit within the limited on-chip memory resources. Furthermore, tiles are designed to align with the :ref:`GPU's execution model `, ensuring that threads within a warp access contiguous memory locations for optimal bandwidth utilization. + +**C++ Usage**: ``using TileShape = sequence<256, 256>;`` + +Distribution +~~~~~~~~~~~~ +The distribution pattern represents one of the most compile-time abstractions in the CK framework, defining the precise mapping between logical data elements and the physical processing resources that will operate on them. A distribution is far more than an assignment scheme—it embodies a strategy for achieving optimal performance on GPU hardware. The distribution determines which threads access which data elements, how those accesses are ordered to maximize memory bandwidth, and how intermediate results are shared between cooperating threads. By encoding these decisions at compile time, distributions enable the generation of highly optimized code that respects hardware constraints while maintaining algorithmic clarity. For a detailed exploration of distribution concepts, see :ref:`ck_tile_distribution`. + +**C++ Type**: ``tile_distribution<...>`` + +Encoding +~~~~~~~~ +An encoding in CK represents a compile-time specification that captures the strategy for distributing tensor data across GPU processing elements. This specification is not merely a configuration but a mathematical description of the transformation between coordinate spaces. The encoding defines the hierarchical decomposition of work, the mapping between thread indices and data elements, and the patterns by which threads cooperate to process their assigned data. 
By expressing these concepts as compile-time constants, encodings enable aggressive compiler optimizations while ensuring that distribution strategies can be verified for correctness before execution. + +**C++ Type**: ``tile_distribution_encoding<...>`` + +Coordinate Spaces +----------------- + +For a comprehensive mathematical treatment of coordinate systems, see :ref:`ck_tile_coordinate_systems`. + +P-Space (Partition Space) +~~~~~~~~~~~~~~~~~~~~~~~~~ +The Partition Space, or P-space, represents the fundamental abstraction for identifying processing elements within the GPU's execution hierarchy. This coordinate space captures the multi-level organization of GPU computation, from individual threads to warps to thread blocks. P-space typically manifests as either a one-dimensional space containing only lane identifiers for simple distributions, or a two-dimensional space incorporating both warp and lane identifiers for more complex hierarchical distributions. The significance of P-space extends beyond mere thread identification—it forms the foundation for all work distribution decisions, determining which processing elements will collaborate on specific data tiles and how they will coordinate their efforts. + +The dimensions of P-space directly reflect the hardware's execution model. In a one-dimensional P-space, threads are identified solely by their lane ID within a warp, suitable for algorithms where inter-warp coordination is minimal. Two-dimensional P-space adds warp-level coordination, enabling advanced tiling strategies that leverage both intra-warp and inter-warp parallelism. The values in P-space are always hardware thread indices, providing a direct mapping to the physical execution resources. + +**C++ Example**: + +.. 
code-block:: cpp + + // Get current thread's P coordinates + auto p_idx = Distribution::_get_partition_index(); + +Y-Space (Yield Space) +~~~~~~~~~~~~~~~~~~~~~ +The Yield Space, or Y-space, embodies the logical structure of computation within each tile, representing the pattern by which threads traverse their assigned data. Unlike P-space which identifies threads, Y-space defines what each thread does with its assigned work. This abstraction enables the expression of complex access patterns—from simple linear traversals to advanced space-filling curves—in a hardware-independent manner. The dimensionality of Y-space varies with the algorithm's requirements, typically ranging from two dimensions for matrix operations to four or more for complex tensor contractions. + +Y-space serves as the primary iteration space for computational kernels. When a thread processes its assigned tile, it iterates through Y-space coordinates, with each coordinate mapping to specific data elements within the tile. This abstraction enables critical optimizations: the Y-space traversal order can be designed to maximize data reuse, minimize register pressure, or optimize for specific hardware characteristics, all without changing the fundamental algorithm. + +**C++ Example**: + +.. code-block:: cpp + + // Iterate over Y-space + sweep_tile(tensor, [](auto y_idx) { /*...*/ }); + +X-Space (Physical Tensor Space) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The Physical Tensor Space, or X-space, represents the ground truth of data organization—the actual coordinates within the global tensor. This space directly corresponds to how data is laid out in memory, with dimensions matching those of the tensor being processed. For a matrix, X-space is two-dimensional with row and column coordinates. For a 4D convolution tensor, X-space encompasses batch, channel, height, and width dimensions. 
X-space serves as the target of the coordinate transformation pipeline, where abstract thread and pattern coordinates are converted into concrete memory addresses. + +The relationship between X-space and physical memory is direct but not necessarily trivial. While X-space coordinates identify logical positions within a tensor, the actual memory layout may involve padding, striding, or other transformations for alignment and performance. The CK framework handles these low-level details transparently, allowing algorithms to work with logical X-space coordinates while ensuring efficient physical memory access. + +**C++ Example**: + +.. code-block:: cpp + + // Calculate X coordinates from P+Y + auto x_idx = distribution.calculate_index(p_idx); + +R-Space (Replication Space) +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The Replication Space, or R-space, introduces a advanced mechanism for expressing redundant computation patterns that enhance performance through data sharing. Unlike the other coordinate spaces which map to unique data elements, R-space enables multiple processing elements to compute the same values, facilitating efficient communication patterns. This replication serves multiple purposes: it can reduce global memory traffic by computing values locally rather than loading them, enable efficient reduction operations by providing private workspace for each thread group, and facilitate complex data exchange patterns that would otherwise require expensive synchronization. + +R-space dimensions are optional and algorithm-specific. A matrix multiplication might use R-space to replicate portions of the input matrices across thread groups, enabling each group to compute partial products independently. The framework automatically manages the complexities of replication, including the allocation of private storage and the coordination of replicated computations. + +**C++ Example**: + +.. 
code-block:: cpp + + // R-dimensions in encoding + using Encoding = tile_distribution_encoding< + sequence<2>, // rs_lengths: 2-way replication + /*...*/ + >; + +D-Space (Data Space) +~~~~~~~~~~~~~~~~~~~~ +The Data Space, or D-space, represents the final stage of the coordinate transformation pipeline—the linearization of multi-dimensional tile data for efficient storage in thread-local registers. This one-dimensional space serves a critical role in managing the GPU's most precious resource: register files. By transforming the potentially complex Y-space coordinates into a linear D-space index, the framework enables efficient register allocation and access patterns that minimize register bank conflicts and maximize instruction-level parallelism. + +The transformation from Y-space to D-space is more than a simple flattening operation. It incorporates optimized strategies for register layout that consider the GPU's register file organization, the kernel's register pressure, and the access patterns of the computation. This transformation ensures that frequently accessed elements are kept in registers, that register bank conflicts are minimized, and that the compiler can generate efficient code for register access. + +**C++ Example**: + +.. code-block:: cpp + + // Y-to-D descriptor linearizes storage + auto d_idx = ys_to_d_descriptor.calculate_offset(y_idx); + +Dimension Types +--------------- + +H-Dimensions (Hierarchical Dimensions) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The concept of Hierarchical Dimensions, or H-dimensions, represents one of the most key aspects of the CK framework's approach to work distribution. These dimensions encode a multi-level decomposition strategy that mirrors the hierarchical nature of GPU hardware, from individual vector operations up through threads, warps, and thread blocks. 
Each H-dimension group captures how a single tensor dimension is partitioned across these hardware levels, enabling fine-grained control over data access patterns and computational efficiency. + +The structure of H-dimensions follows a specific pattern that reflects the GPU's execution hierarchy. Each H-dimension is expressed as a sequence of factors, where each factor corresponds to a specific level of the hierarchy. Consider the example ``sequence<4, 2, 8, 4>``. This seemingly simple sequence encodes a advanced distribution strategy: the rightmost factor (4) represents vector width, indicating that each memory operation processes 4 elements simultaneously. Moving left, the factor 8 indicates that 8 threads within a warp collaborate on the data. The factor 2 specifies that 2 warps within a block work together. Finally, the leftmost factor 4 indicates that each thread performs 4 iterations, enabling instruction-level parallelism and register reuse. + +This hierarchical decomposition enables critical optimizations. By explicitly encoding the distribution strategy at compile time, the framework can generate code that perfectly matches the hardware's capabilities. The vector width aligns with the GPU's memory transaction size. The thread count per warp matches the hardware's SIMD width. The warp count per block balances parallelism with resource constraints. The repetition factor enables loop unrolling and software pipelining. Together, these factors create a distribution strategy that achieves near-optimal performance. + +**C++ Example**: + +.. code-block:: cpp + + using HsLengthss = tuple< + sequence<4, 2, 8, 4>, // H0: M dimension + sequence<4, 2, 8, 4> // H1: N dimension + >; + +RH-Dimensions (R + H Dimensions Combined) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The RH-dimensions represent the unified coordinate space that combines both replication (R) and hierarchical (H) dimensions into a single, coherent framework. 
This combined space serves as the internal representation used by the coordinate transformation machinery, enabling seamless handling of both replicated and non-replicated data patterns. The unification of these dimensions simplifies the mathematical framework while maintaining the flexibility to express complex distribution strategies. + +Within the RH-dimension framework, coordinates are identified by two components: major and minor indices. The major index identifies which dimension group a coordinate belongs to, with 0 reserved for R-dimensions and subsequent values (1, 2, ...) identifying H-dimension groups. The minor index specifies the position within the identified group. This two-level addressing scheme enables efficient navigation through the combined coordinate space while maintaining clear separation between replication and hierarchical decomposition strategies. + +The power of RH-dimensions becomes apparent when considering complex algorithms that require both data replication and hierarchical distribution. By providing a unified coordinate system, the framework can express transformations that simultaneously handle replicated data sharing and hierarchical work distribution, all within a single mathematical formalism. This unification is key to achieving both expressiveness and efficiency in the CK framework. + +Transformations +--------------- + +Adaptor +~~~~~~~ +An adaptor in the CK framework represents a advanced chain of coordinate transformations that bridges different coordinate spaces. Rather than simple one-to-one mappings, adaptors embody complex mathematical transformations that can involve permutations, embeddings, projections, and non-linear mappings. These transformations are composed at compile time, enabling the generation of highly optimized code that performs the complete transformation in a single step without intermediate representations. For detailed information about adaptors and their implementation, see :ref:`ck_tile_adaptors`. 
+ +The framework provides several specialized adaptor types, each serving a specific role in the coordinate transformation pipeline. The ``ps_ys_to_xs_adaptor`` performs the critical transformation from processing element and yield space coordinates to physical tensor coordinates, implementing the core logic of tile distribution. This adaptor encodes decisions about how threads are assigned to data, how data is traversed within each thread's assignment, and how these patterns map to the global tensor layout. Similarly, the ``ys_to_d_adaptor`` handles the transformation from multi-dimensional yield space to linearized data space, optimizing the layout of data in thread-local registers. + +The power of adaptors lies in their composability. Complex transformations can be built by chaining simpler adaptors, with the framework automatically optimizing the composition. This design enables the expression of advanced access patterns—such as transposed access, strided access, or space-filling curves—through the composition of elementary transformations. The compile-time nature of this composition ensures zero runtime overhead while maintaining mathematical clarity. + +**C++ Type**: ``tensor_adaptor<...>`` + +Descriptor +~~~~~~~~~~ +A descriptor in CK provides a complete specification of tensor layout, encompassing not just the logical structure of the data but also all transformations and physical memory layout details. This comprehensive specification serves as the contract between different components of the system, ensuring that all parts of a kernel have a consistent view of how data is organized and accessed. Descriptors combine multiple aspects of tensor representation: the logical shape and dimensions, the physical memory layout including padding and alignment, the coordinate transformations for different access patterns, and optimization hints for the compiler. For comprehensive coverage of descriptors, see :ref:`ck_tile_descriptors`. 
+ +The sophistication of descriptors enables them to represent complex data layouts that arise in real-world applications. A descriptor might specify that a logically 4D tensor is physically stored with padding for alignment, uses a custom stride pattern for the channel dimension, and should be accessed using a space-filling curve for optimal cache utilization. All these details are encoded in the descriptor's type, enabling compile-time verification and optimization. + +Descriptors play a crucial role in achieving performance portability. By abstracting the details of data layout behind a well-defined interface, descriptors enable algorithms to be written once and automatically adapted to different data layouts. This abstraction is particularly valuable when dealing with different hardware architectures that may have different alignment requirements, cache line sizes, or memory access patterns. + +**C++ Type**: ``tensor_descriptor<...>`` + +Operations +---------- + +Load Tile +~~~~~~~~~ +The load tile operation represents a fundamental building block of GPU kernel design in the CK framework, orchestrating the complex process of transferring data from global memory to thread-local registers. This operation is far more advanced than a simple memory copy—it implements the complete distribution strategy encoded in the tile distribution, ensuring that each thread loads exactly the data it needs for its portion of the computation. The load operation automatically handles memory coalescing to maximize bandwidth utilization, coordinates between threads to avoid redundant loads, manages boundary conditions for tiles that extend beyond tensor bounds, and optimizes the access pattern based on the specific distribution strategy. + +The efficiency of the load tile operation stems from its deep integration with the distribution framework. 
By knowing at compile time exactly which threads will access which data elements, the operation can generate optimal memory access patterns that fully utilize the GPU's memory subsystem. For matrix multiplication, this might mean loading data in a pattern that ensures perfect coalescing. For convolution, it might involve complex patterns that minimize the number of redundant loads while respecting the GPU's cache hierarchy. + +**C++ Function**: ``tile_window.load()`` + +Store Tile +~~~~~~~~~~ +The store tile operation provides the complementary functionality to load tile, transferring computed results from thread-local registers back to global memory. Like its counterpart, the store operation implements optimized strategies that go beyond simple memory writes. It ensures that writes are coalesced for maximum bandwidth efficiency, coordinates between threads to handle overlapping write regions correctly, manages atomic operations when multiple threads write to the same location, and optimizes write patterns to minimize memory traffic. + +The store operation must handle additional complexities compared to loads. While loads can often ignore synchronization issues (reading stale data is usually harmless), stores must ensure correctness when multiple threads write to overlapping regions. The framework provides different store modes for different scenarios: exclusive stores where each element is written by exactly one thread, atomic stores where multiple threads may update the same element, and reduction stores where partial results are accumulated. The choice of store mode is encoded in the distribution strategy and verified at compile time. + +**C++ Function**: ``tile_window.store(tile)`` + +Sweep Tile +~~~~~~~~~~ +The sweep tile operation embodies a key programming paradigm for distributed tensor computation, providing a high-level iteration abstraction over the complex distribution patterns. 
Rather than requiring manual index calculations and nested loops, sweep tile automatically visits each element in a distributed tensor exactly once, invoking a user-provided function with the appropriate coordinates. This abstraction hides the complexity of the distribution while enabling advanced optimizations such as automatic loop unrolling, software pipelining, and register rotation. + +The implementation of sweep tile leverages the compile-time knowledge of the distribution pattern to generate highly optimized iteration code. For simple distributions, this might result in a single unrolled loop. For complex hierarchical distributions, it might generate nested loops with carefully chosen iteration orders that maximize data reuse and minimize register pressure. The beauty of the abstraction is that these optimizations happen transparently—the user simply provides the computation to perform on each element, and the framework handles the rest. + +**C++ Function**: ``sweep_tile(tensor, lambda)`` + +Shuffle Tile +~~~~~~~~~~~~ +The shuffle tile operation provides efficient intra-warp communication, enabling threads within a warp to exchange data without going through shared memory. This operation leverages the GPU's hardware shuffle instructions, which allow any thread in a warp to read registers from any other thread in the same warp. Shuffle operations are particularly valuable for reduction operations, transpose operations within a warp, and collaborative loading patterns where threads cooperate to load contiguous data and then redistribute it according to the computation pattern. + +The framework provides various shuffle patterns optimized for different use cases. Butterfly shuffles enable efficient reductions and FFT-like operations. Broadcast shuffles allow one thread to share data with all others in the warp. Rotation shuffles enable cyclic data exchange patterns. 
The shuffle tile operation automatically selects the appropriate hardware instructions based on the data type and shuffle pattern, ensuring optimal performance while maintaining portability across different GPU architectures. + +**C++ Function**: ``shuffle_tile(tensor, shuffle_pattern)`` + +Memory Concepts +--------------- + +Coalescing +~~~~~~~~~~ +The property where adjacent threads access adjacent memory locations, maximizing memory bandwidth utilization. + +Bank Conflict +~~~~~~~~~~~~~ +A performance degradation that occurs when multiple threads in a warp access different addresses in the same memory bank. For detailed information about bank conflicts and mitigation strategies, see :ref:`ck_tile_lds_bank_conflicts`. + +Vectorization +~~~~~~~~~~~~~ +The technique of loading/storing multiple elements in a single memory transaction. + +**C++ Example**: + +.. code-block:: cpp + + // Vector load of 4 elements + using float4 = vector_type::type; + float4 data = tensor_view.template get_vectorized_elements<4>(x_idx); + +Distribution Components +----------------------- + +Window +~~~~~~ +A view into a subset of a tensor that respects the distribution pattern. For detailed information about tile windows and their usage, see :ref:`ck_tile_tile_window`. + +**C++ Type**: ``tile_window<...>`` + +Static Distributed Tensor +~~~~~~~~~~~~~~~~~~~~~~~~~ +A thread-local tensor stored in registers, distributed according to a tile distribution. For in-depth coverage of static distributed tensors, see :ref:`ck_tile_static_distributed_tensor`. + +**C++ Type**: ``static_distributed_tensor<...>`` + +Spans +~~~~~ +Iteration ranges over distributed dimensions, used by sweep operations. + +**C++ Type**: ``tile_distributed_span<...>`` + +GPU Hardware Terms +------------------ + +Warp +~~~~ +A group of threads (32 on AMD GPUs) that execute in lockstep. + +Lane +~~~~ +An individual thread within a warp (0-31). + +Block +~~~~~ +A group of warps that can cooperate through shared memory. 
+ +Grid +~~~~ +The complete set of blocks launched for a kernel. + +Template Parameters +------------------- + +sequence<...> +~~~~~~~~~~~~~ +A compile-time integer sequence used to specify dimensions and lengths. + +**Example**: ``sequence<256, 256>`` for a 256×256 tile + +tuple<...> +~~~~~~~~~~ +A heterogeneous collection of types, often used for grouping sequences. + +**Example**: ``tuple, sequence<4,4>>`` + +number +~~~~~~~~~ +A compile-time integer constant. + +**Example**: ``number<16>`` represents the value 16 + +Optimization Terms +------------------ + +Register Spilling +~~~~~~~~~~~~~~~~~ +When a kernel uses more registers than available, causing data to spill to slower memory. + +Occupancy +~~~~~~~~~ +The ratio of active warps to maximum possible warps on a GPU multiprocessor. + +Memory Bandwidth Utilization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The percentage of theoretical memory bandwidth achieved by a kernel. + +Instruction-Level Parallelism (ILP) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The ability to execute multiple independent instructions simultaneously. + +Common Patterns +--------------- + +GEMM (General Matrix Multiplication) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +A fundamental operation where C = αA×B + βC. For a complete optimization case study, see :ref:`ck_tile_gemm_optimization`. + +Reduction +~~~~~~~~~ +An operation that combines multiple values into a single result (e.g., sum, max). + +Broadcast +~~~~~~~~~ +An operation that replicates a value across multiple processing elements. + +Transpose +~~~~~~~~~ +An operation that swaps dimensions of a tensor. + +Performance Metrics +------------------- + +FLOPS (Floating-Point Operations Per Second) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Measure of computational throughput. + +Bandwidth +~~~~~~~~~ +Rate of data transfer, typically measured in GB/s. + +Latency +~~~~~~~ +Time delay between issuing an operation and its completion. 
+ +Throughput +~~~~~~~~~~ +Rate of operation completion, often measured in operations per second. + +Usage Examples +-------------- + +Creating a Distribution +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + // Define encoding + using MyEncoding = tile_distribution_encoding< + sequence<>, // No replication + tuple, // M dimension + sequence<4,2,8,4>>, // N dimension + tuple, sequence<1,2>>, // P mappings + tuple, sequence<2,2>>, // P minor + sequence<1,1,2,2>, // Y major + sequence<0,3,0,3> // Y minor + >; + + // Create distribution + auto distribution = make_static_tile_distribution(MyEncoding{}); + +Using Tile Window +~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + // Create window + auto window = make_tile_window( + tensor_view, + TileShape{}, + origin, + distribution + ); + + // Load-compute-store pattern + auto tile = window.load(); + sweep_tile(tile, compute_func); + window.store(tile); + +Related Documentation +--------------------- + +- :ref:`ck_tile_introduction` - Introduction and motivation +- :ref:`ck_tile_buffer_views` - Raw memory access +- :ref:`ck_tile_distribution` - Core distribution concepts + + diff --git a/docs/conceptual/ck_tile/thread_mapping.rst b/docs/conceptual/ck_tile/thread_mapping.rst new file mode 100644 index 0000000000..cff4f727ff --- /dev/null +++ b/docs/conceptual/ck_tile/thread_mapping.rst @@ -0,0 +1,551 @@ +.. meta:: + :description: CK Tile thread mapping - connecting mathematical abstractions to GPU hardware + :keywords: CDNA, RDNA, ROCm, CK, Composable Kernel, thread mapping, GPU programming + +.. _ck_tile_thread_mapping: + +******************************************************************** +Thread Mapping - Connecting to Hardware +******************************************************************** + +This section explains how threads get their unique IDs and how those map to specific data, and connecting mathematical abstractions to physical hardware. 
+ +Thread mapping is the bridge between the mathematical abstraction and the physical hardware that executes the code. Thread mapping works closely with :ref:`ck_tile_tile_distribution` to ensure optimal performance. + +Thread Identification and Partition Indices +=========================================== + +Before threads can process data, they need to know who they are and what work they're responsible for. + +Hardware Thread Identification +------------------------------ + +In GPU hardware, threads are organized hierarchically: + +.. code-block:: cpp + + // CUDA/HIP thread identification + __device__ void get_thread_coordinates() + { + // Grid-level coordinates (which block) + int block_x = blockIdx.x; + int block_y = blockIdx.y; + int block_z = blockIdx.z; + + // Block-level coordinates (which thread in block) + int thread_x = threadIdx.x; + int thread_y = threadIdx.y; + int thread_z = threadIdx.z; + + // Warp identification + int warp_id = threadIdx.x / 32; // 32 threads per warp + int lane_id = threadIdx.x % 32; // Position within warp + + // Global thread ID calculation + int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; + } + +C++ Thread Mapping in CK +------------------------ + +Composable Kernel abstracts thread identification into partition indices, building on the :ref:`ck_tile_coordinate_systems` foundation: + +.. code-block:: cpp + + // From tile_partition.hpp + template + struct tile_partition + { + CK_TILE_DEVICE static constexpr index_t get_thread_idx() + { + return threadIdx.x; + } + + CK_TILE_DEVICE static constexpr index_t get_block_idx() + { + return blockIdx.x; + } + + // Convert to multi-dimensional partition index + template + CK_TILE_DEVICE static constexpr auto get_partition_index() + { + constexpr auto thread_layout = ThreadLayout{}; + + // Convert linear thread ID to multi-dimensional index + return thread_layout.template get_index(get_thread_idx()); + } + }; + + +.. 
+ Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "GPU Device" + subgraph "Thread Block" + subgraph "Warp 0" + T0["Thread 0
lane_id=0"] + T1["Thread 1
lane_id=1"] + T2["..."] + T31["Thread 31
lane_id=31"] + end + + subgraph "Warp 1" + T32["Thread 32
lane_id=0"] + T33["Thread 33
lane_id=1"] + T34["..."] + T63["Thread 63
lane_id=31"] + end + + W2["Warp 2"] + W3["..."] + W7["Warp 7"] + end + end + + subgraph "Thread Identification" + TID["Thread ID = blockIdx.x * blockDim.x + threadIdx.x"] + WID["Warp ID = threadIdx.x / 32"] + LID["Lane ID = threadIdx.x % 32"] + end + + subgraph "P-space Mapping" + P["P-coordinates
NDimP=1: [thread_id]
NDimP=2: [warp_id, lane_id]"] + end + + T0 --> TID + TID --> WID + TID --> LID + WID --> P + LID --> P + + style T0 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style T32 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style P fill:#fff3e0,stroke:#f57c00,stroke-width:3px + + + + + +.. image:: diagrams/thread_mapping_1.svg + :alt: Diagram + :align: center + + +Thread Hierarchy Structure +-------------------------- + +The hardware organizes threads in a specific hierarchy. See :ref:`ck_tile_gpu_basics` for hardware details. + +**Block Level**: Groups of warps working together + +- Warps per block defined by encoding, for example, 2×2 warps +- Shared memory and synchronization scope +- Block-level coordination possible + +**Warp Level**: Groups of threads executing in lockstep + +- Threads per warp defined by encoding, for example, 8×8 threads +- SIMD execution (all threads execute same instruction) +- Warp-level primitives (shuffle, vote, etc.) + +**Thread Level**: Individual execution units + +- Vector size per thread, for example, 4×4 elements +- Independent register space +- Vector operations on multiple elements + +Thread ID Mapping +----------------- + +Each thread gets a unique ID that maps to its position in the hierarchy. For example, in an RMSNorm configuration: + +- **Repeat (M, N)**: (4, 4) - Number of iterations +- **Warps per block (M, N)**: (2, 2) - 4 warps total +- **Threads per warp (M, N)**: (8, 8) - 64 threads per warp +- **Vector size (M, N)**: (4, 4) - 16 elements per thread + +This gives us: + +- **Threads per block**: 256 (4 warps × 64 threads/warp) +- **Elements per thread**: 16 (4×4 vector) +- **Total elements**: 4096 per block + +Thread-to-Data Mapping +====================== + +Once threads know their IDs, they need to map those IDs to specific data elements. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Thread to Data Mapping" + subgraph "Thread Grid" + T00["Thread[0,0]
Warp 0"] + T01["Thread[0,1]
Warp 0"] + T10["Thread[1,0]
Warp 1"] + T11["Thread[1,1]
Warp 1"] + end + + subgraph "Data Tiles" + D00["Data[0:4, 0:4]
16 elements"] + D01["Data[0:4, 4:8]
16 elements"] + D10["Data[4:8, 0:4]
16 elements"] + D11["Data[4:8, 4:8]
16 elements"] + end + + subgraph "Memory Access" + MA["Coalesced Access
Adjacent threads → Adjacent memory"] + end + end + + T00 --> D00 + T01 --> D01 + T10 --> D10 + T11 --> D11 + + D00 --> MA + D01 --> MA + + style T00 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style D00 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style MA fill:#fff3e0,stroke:#f57c00,stroke-width:2px + + + + + +.. image:: diagrams/thread_mapping_2.svg + :alt: Diagram + :align: center + +Data Distribution Pattern +------------------------- + +The RMSNorm operation distributes tensor data across threads in a structured pattern: + +**Hierarchical Data Distribution:** + +- **Block Level**: Multiple iterations (repeat factor) +- **Warp Level**: Warps process different regions +- **Thread Level**: Threads within warp handle adjacent data +- **Vector Level**: Each thread processes multiple elements + +Thread Work Assignment +---------------------- + +Each thread is assigned a specific rectangular region of the tensor. For example: + +- Thread in Warp[0,0] Thread[0,0] might process: + + - Data region (M): [0:4) + - Data region (N): [0:4) + - Total elements: 16 + +- Thread in Warp[0,0] Thread[0,1] might process: + + - Data region (M): [0:4) + - Data region (N): [4:8) + - Total elements: 16 + +This pattern ensures adjacent threads access adjacent memory for optimal coalescing. The :ref:`ck_tile_load_store_traits` system further optimizes these access patterns. + +Thread Cooperation Patterns +=========================== + +Threads don't work in isolation. Threads cooperate at different levels to achieve optimal performance. 
+ +Warp-Level Cooperation +---------------------- + +Threads within a warp execute in lockstep (SIMD): + +- **Synchronization**: Automatic SIMD execution +- **Data sharing**: Warp shuffle instructions +- **Collective ops**: Warp-level reductions +- **Memory access**: Coalesced patterns + +Block-Level Cooperation +----------------------- + +Threads within a block can share data and synchronize: + +- **Shared memory**: All threads in block can access (see :ref:`ck_tile_lds_bank_conflicts` for optimization) +- **Synchronization**: ``__syncthreads()`` barriers +- **Data exchange**: Through shared memory +- **Collective operations**: Block-wide reductions + +Vector-Level Processing +----------------------- + +Each thread processes multiple elements: + +- **Register efficiency**: Multiple elements in registers +- **Memory coalescing**: Vectorized loads/stores +- **Instruction efficiency**: SIMD operations on vectors +- **Bandwidth utilization**: Maximum memory throughput + +Memory Access Patterns +====================== + +The thread mapping directly affects memory access. + +C++ Implementation of Memory Access +----------------------------------- + +Here's how CK implements memory access patterns: + +.. 
code-block:: cpp + + // Coalesced memory access pattern + template + __device__ void coalesced_load(const DataType* __restrict__ src, + DataType* __restrict__ dst, + index_t tid) + { + // Each thread loads VectorSize elements + // Adjacent threads access adjacent memory + constexpr index_t stride = blockDim.x; + + // Vectorized load for efficiency + using vector_t = vector_type_t; + + // Calculate aligned address + const vector_t* src_vec = reinterpret_cast( + src + tid * VectorSize); + + // Single vectorized load instruction + vector_t data = *src_vec; + + // Store to registers + reinterpret_cast(dst)[0] = data; + } + + // CK's distributed tensor load implementation + template + __device__ void load_tile_window(DistributedTensor& dist_tensor, + const auto& tile_window) + { + // Get thread's partition index + constexpr auto partition = tile_partition::get_partition_index(); + + // Each thread loads its assigned data + tile_window.load(dist_tensor, partition); + + // Hardware automatically coalesces adjacent thread accesses + } + +Memory Access Optimization Techniques +------------------------------------- + +CK uses several techniques to optimize memory access: + +.. code-block:: cpp + + // 1. Vector loads for maximum bandwidth + template + using vector_load_t = conditional_t>>; + + // 2. Swizzling to avoid bank conflicts + // See :ref:`ck_tile_lds_index_swapping` and :ref:`ck_tile_swizzling_example` + template + __device__ index_t swizzle_offset(index_t tid, index_t offset) + { + // Rotate access pattern to avoid conflicts + return (offset + (tid / BankSize)) % BankSize; + } + + // 3. 
Prefetching for latency hiding + __device__ void prefetch_next_tile(const float* src, index_t offset) + { + // Prefetch to L2 cache + __builtin_prefetch(src + offset, 0, 3); + } + +Memory Efficiency Benefits +-------------------------- + +The structured thread mapping provides several memory efficiency benefits: + +**Memory Coalescing Benefits:** + +- **Adjacent access**: Threads in same warp access adjacent memory locations +- **Cache efficiency**: Related data loaded together into cache lines +- **Bandwidth utilization**: Maximum memory bandwidth achieved +- **Reduced latency**: Fewer memory transactions needed + +**Performance Characteristics:** + +- **Predictable patterns**: Access patterns known at compile time +- **Vectorization**: Hardware can optimize vector operations +- **Reduced overhead**: No complex address calculations at runtime +- **Scalability**: Pattern scales efficiently with thread count + +Practical Thread Mapping Example +================================ + +Complete C++ Kernel Example +--------------------------- + +The following example shows how thread mapping works in a CK kernel: + +.. code-block:: cpp + + // RMSNorm kernel using CK's thread mapping + template + __global__ void rmsnorm_kernel(const DataType* __restrict__ x, + DataType* __restrict__ y, + const DataType* __restrict__ weight, + ComputeType epsilon, + index_t hidden_size) + { + // 1. Thread identification + const index_t tid = threadIdx.x; + const index_t bid = blockIdx.x; + + // 2. Create tile distribution encoding + // This would be defined based on your specific RMSNorm pattern + using Encoding = tile_distribution_encoding< + sequence<>, // No replication + tuple, sequence<4, 2>>, // H dimensions + tuple, sequence<2>>, // P to RH major + tuple, sequence<0>>, // P to RH minor + sequence<1, 2>, // Y to RH major + sequence<0, 0> // Y to RH minor + >; + constexpr auto tile_dist = make_static_tile_distribution(Encoding{}); + + // 3. 
Get thread's partition index from distribution + const auto partition_idx = tile_dist._get_partition_index(); + + // 4. Shared memory for reduction + __shared__ ComputeType shared_sum[BlockSize]; + + // 5. Create tensor view and tile window + // See :ref:`ck_tile_tensor_views` and :ref:`ck_tile_tile_window` + auto x_view = make_naive_tensor_view( + x + bid * hidden_size, + make_tuple(hidden_size), + make_tuple(number<1>{}) + ); + + auto x_window = make_tile_window( + x_view, + make_tuple(hidden_size), + make_tuple(number<0>{}), + tile_dist); + + // 6. Each thread processes its assigned elements + ComputeType thread_sum = 0; + static_for<0, VectorSize, 1>{}([&](auto i) { + // Access pattern would depend on your tile window setup + // This is conceptual - actual implementation varies + thread_sum += val * val; + }); + + // 7. Warp-level reduction + thread_sum = warp_reduce_sum(thread_sum); + + // 8. Block-level reduction + if (tid % WarpSize == 0) { + shared_sum[tid / WarpSize] = thread_sum; + } + __syncthreads(); + + // 9. Final reduction by first warp + if (tid < BlockSize / WarpSize) { + thread_sum = shared_sum[tid]; + thread_sum = warp_reduce_sum(thread_sum); + } + + // 10. Compute RMS and normalize + if (tid == 0) { + shared_sum[0] = rsqrt(thread_sum / hidden_size + epsilon); + } + __syncthreads(); + + const ComputeType rms_recip = shared_sum[0]; + + // 11. Write normalized output + auto y_window = make_tile_window( + make_tensor_view(y + bid * hidden_size), + tile_dist); + + static_for<0, VectorSize, 1>{}([&](auto i) { + auto idx = tile_dist.get_tensor_coordinate(partition_idx, i); + ComputeType val = static_cast(x_window.get(idx)); + ComputeType w = static_cast(weight[idx[1]]); + y_window.set(idx, static_cast(val * rms_recip * w)); + }); + } + +Key Thread Mapping Concepts in Action +------------------------------------- + +1. **Thread-to-Data Assignment**: Each thread gets a unique ``partition_idx`` +2. 
**Vectorized Access**: Each thread processes ``VectorSize`` elements +3. **Warp Cooperation**: Threads within a warp perform reductions +4. **Block Synchronization**: All threads synchronize for final result +5. **Coalesced Memory**: Adjacent threads access adjacent memory + +Key Takeaways +============= + +Thread mapping is the bridge between mathematical abstractions and physical hardware execution: + +**Thread Identification:** + +1. **Hierarchical Organization**: Threads organized in blocks → warps → threads → vectors + + - Each level has specific cooperation capabilities + - Hardware provides efficient primitives at each level + - Thread IDs map directly to data regions + - Predictable and efficient execution patterns + +2. **Data Assignment**: Each thread gets a specific rectangular region + + - Work distributed evenly across threads + - Memory access patterns optimized for coalescing + - Vector operations maximize throughput + - Scalable across different hardware configurations + +3. **Cooperation Patterns**: Threads cooperate at multiple levels + + - Warp-level SIMD execution for efficiency + - Block-level shared memory and synchronization + - Vector-level processing for maximum throughput + - Hierarchical coordination for complex operations + +**Performance Benefits:** + +- **Memory Coalescing**: Adjacent threads access adjacent memory for optimal bandwidth +- **Cache Efficiency**: Related data loaded together, reducing memory latency +- **Vectorization**: Hardware can optimize multiple operations per thread +- **Predictable Patterns**: Compile-time optimization of access patterns + +**Why This Matters:** + +Thread mapping connects encodings, transformations, and distributions to hardware execution. + +The RMSNorm example shows how a real operation uses these concepts to achieve optimal performance on GPU hardware. Every thread knows exactly what data to process, how to access it efficiently, and how to cooperate with other threads. 
+ + +Related Topics + +- :ref:`ck_tile_descriptors` - Complete tensor specifications that thread mapping uses +- :ref:`ck_tile_coordinate_movement` - Advanced coordinate operations for thread navigation +- :ref:`ck_tile_sweep_tile` - How threads iterate over distributed data +- :ref:`ck_tile_gemm_optimization` - Real-world application of thread mapping in GEMM kernels +- :ref:`ck_tile_space_filling_curve` - Optimal traversal patterns for thread access diff --git a/docs/conceptual/ck_tile/tile_distribution.rst b/docs/conceptual/ck_tile/tile_distribution.rst new file mode 100644 index 0000000000..c57a87e5ce --- /dev/null +++ b/docs/conceptual/ck_tile/tile_distribution.rst @@ -0,0 +1,627 @@ +.. _ck_tile_distribution: + +Tile Distribution - The Core API +================================ + +Overview +-------- + +At the heart of Composable Kernel's approach to efficient GPU computation lies TileDistribution, a compile-time abstraction that transforms how developers approach parallel programming on GPUs. Rather than requiring programmers to manually manage thread coordination, memory access patterns, and data distribution, TileDistribution provides a mathematical framework that automatically maps logical computational coordinates to physical execution resources. + +The architectural foundation of tile distribution in CK rests upon the :ref:`coordinate transformation system ` that bridges multiple abstract spaces. This system manages the interaction between four primary coordinate dimensions, each serving a distinct purpose in the overall computation model. The X dimensions represent the physical tensor coordinates, capturing the actual layout of data in memory. The Y dimensions encode the tile access patterns, defining how threads traverse their assigned data. The P dimensions map to processing elements, representing the hierarchical organization of threads, warps, and blocks in the :ref:`GPU's execution model `. 
Additionally, the optional R dimensions enable replication strategies for algorithms that benefit from redundant computation to reduce communication overhead. + +This multi-dimensional mapping framework enables CK to express arbitrarily complex data access patterns through a mathematical formalism. The power of this approach becomes evident when considering how traditional GPU programming requires developers to manually calculate memory addresses, ensure coalescing constraints, :ref:`avoid bank conflicts `, and manage thread cooperation. TileDistribution handles all these concerns within a unified abstraction that can be analyzed, optimized, and verified at compile time. + +The ``tile_distribution`` template class integrates three essential components that work together to deliver optimal performance. The ``PsYs2XsAdaptor`` component performs :ref:`coordinate transformations ` from processing and pattern dimensions to physical tensor coordinates, implementing the mathematical mappings that ensure efficient memory access. The ``Ys2DDescriptor`` component handles the linearization of Y dimensions, transforming multi-dimensional tile patterns into register allocation schemes that maximize register reuse and minimize register pressure. The ``StaticTileDistributionEncoding`` captures the hierarchical decomposition of work across the GPU's compute resources, encoding decisions about how work is partitioned across thread blocks, warps, and individual threads. + +This design adapts to diverse computational scenarios without manual intervention. The same high-level code can execute on GPUs with different numbers of streaming multiprocessors, varying warp sizes, or distinct memory hierarchies. The compile-time nature of the abstraction ensures that all coordination logic is resolved during compilation, resulting in machine code that is comparable to hand-optimized implementations. 
This adaptability enables a single implementation to achieve improved performance across a wide range of tensor sizes, shapes, and computational patterns without the combinatorial explosion of specialized kernels. + +Complete Tile Distribution System Overview +------------------------------------------ + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Logical View" + T["Tensor
Multi-dimensional data"] + TD["TileDistribution
Work assignment"] + TW["TileWindow
Data view"] + end + + subgraph "Coordinate Spaces" + X["X: Physical tensor coords"] + Y["Y: Tile pattern coords"] + P["P: Processing element coords"] + R["R: Replication coords (optional)"] + end + + subgraph "GPU Execution" + W["Warps
32 threads each"] + L["Lanes
Thread within warp"] + REG["Registers
Thread-local storage"] + end + + T --> TD + TD --> TW + + TD --> X + TD --> Y + TD --> P + TD --> R + + P --> W + P --> L + TW --> REG + + style TD fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style P fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style REG fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + + + + +.. image:: diagrams/tile_distribution_1.svg + :alt: Diagram + :align: center + +Coordinate System Architecture +------------------------------ + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + flowchart LR + subgraph "Input" + TC["Thread Coordinates
(warpId, laneId)"] + end + + subgraph "Transformation Pipeline" + P2Y["P → Y
Thread to pattern"] + Y2X["Y → X
Pattern to physical"] + Y2D["Y → D
Pattern to register"] + end + + subgraph "Output" + MC["Memory Coordinates
Global addresses"] + RI["Register Indices
Local storage"] + end + + TC --> P2Y + P2Y --> Y2X + P2Y --> Y2D + Y2X --> MC + Y2D --> RI + + style TC fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + style MC fill:#d1fae5,stroke:#10b981,stroke-width:2px + style RI fill:#fef3c7,stroke:#f59e0b,stroke-width:2px + + + + + + +.. image:: diagrams/tile_distribution_2.svg + :alt: Diagram + :align: center + +What is Tile Distribution? +-------------------------- + +In GPU programming, distributing work across thousands of parallel threads is an important challenge. Consider a 256×256 matrix multiplication operation and 64 GPU threads organized in warps. The question becomes how to divide this computational work in a way that maximizes memory bandwidth utilization, minimizes bank conflicts, and ensures coalesced memory accesses. + +The traditional approach without a tile distribution framework requires programmers to manually calculate global memory addresses for each thread, implement complex index arithmetic that accounts for thread hierarchy (threads within warps, warps within blocks), handle edge cases for non-divisible matrix dimensions, and create different implementations for various matrix sizes. This manual approach is not only error-prone but also fails to adapt to different GPU architectures and their specific memory access patterns. + +TileDistribution solves these challenges through a systematic approach to work distribution. It automatically assigns work to threads based on a hierarchical decomposition of the problem space, generates memory access patterns that respect GPU hardware constraints, provides a uniform interface that works across different tensor sizes and shapes, and ensures optimal thread cooperation by automatically managing data movement to thread-local registers. + +TileDistribution abstracts the mapping between logical problem coordinates and physical execution resources. 
Given a thread's position in the GPU's execution hierarchy (specified by warp ID and lane ID within the warp), TileDistribution computes two critical pieces of information: the global memory addresses that this thread should access, and the specific access pattern that ensures efficient memory transactions. This abstraction is implemented in C++ through the following core structure: + +.. code-block:: cpp + + template + struct tile_distribution + { + // Core functionality: map thread coordinates to data + CK_TILE_HOST_DEVICE static auto _get_partition_index() + { + if constexpr(NDimP == 1) + return array{get_lane_id()}; + else if constexpr(NDimP == 2) + return array{get_warp_id(), get_lane_id()}; + } + + // Calculate which tensor elements this thread accesses + template + CK_TILE_HOST_DEVICE static auto calculate_tile_Ys_index(const PartitionIndex& ps_idx) + { + return detail::calculate_tile_Ys_index( + StaticTileDistributionEncoding{}, ps_idx); + } + }; + +Problem Space Mapping +--------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + + graph TB + subgraph "Problem Space (256×256 Matrix)" + M["Full Matrix
65,536 elements"] + T1["Tile 1
32×32"] + T2["Tile 2
32×32"] + TN["Tile N
32×32"] + end + + subgraph "Thread Assignment" + W0["Warp 0
32 threads"] + W1["Warp 1
32 threads"] + L0["Lane 0-31
Individual threads"] + end + + subgraph "Memory Pattern" + MP["Coalesced Access
Sequential addresses
No bank conflicts"] + end + + M --> T1 + M --> T2 + M --> TN + + T1 --> W0 + T1 --> W1 + W0 --> L0 + L0 --> MP + + style M fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style MP fill:#d1fae5,stroke:#10b981,stroke-width:2px + + + + + + +.. image:: diagrams/tile_distribution_3.svg + :alt: Diagram + :align: center + +Creating a TileDistribution +--------------------------- + +Creating and using a TileDistribution: + +.. code-block:: cpp + + // SPDX-License-Identifier: MIT + // Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + + #include "ck_tile/host.hpp" + #include "ck_tile/core.hpp" + #include + #include + #include + + namespace ck_tile { + + struct TileDistributionExample + { + CK_TILE_DEVICE void operator()(float* global_data, + ck_tile::index_t global_shape_0, + ck_tile::index_t global_shape_1) const + { + if(threadIdx.x == 0 && blockIdx.x == 0) { + printf("\n=== Tile Distribution Example (Device Kernel) ===\n"); + } + block_sync_lds(); + + // Create a tile distribution encoding + // This defines how a tensor is distributed across threads + auto encoding = tile_distribution_encoding< + sequence<>, // rs_lengths=[] - No replication dimensions + tuple< + sequence<2, 2>, // hs_lengthss=[[2, 2], [2, 2]] - Hierarchical lengths for each X dimension + sequence<2, 2>>, + tuple, sequence<2>>, // ps_to_rhss_major=[[1], [2]] - P to RH major mappings + tuple, sequence<0>>, // ps_to_rhss_minor=[[0], [0]] - P to RH minor mappings + sequence<1, 2>, // ys_to_rhs_major=[1, 2] - Y to RH major mappings + sequence<1, 1>>{}; // ys_to_rhs_minor=[1, 1] - Y to RH minor mappings + + // Create the tile distribution from the encoding + auto distribution = make_static_tile_distribution(encoding); + + // Calculate sizes from the distribution encoding + // x0_size = np.prod(distribution.encoding.hs_lengthss[0]) + constexpr auto hs_lengths_0 = encoding.hs_lengthss_[number<0>{}]; // sequence<2, 2> + constexpr auto hs_lengths_1 = encoding.hs_lengthss_[number<1>{}]; // 
sequence<2, 2> + + constexpr index_t x0_size = reduce_on_sequence(hs_lengths_0, multiplies{}, number<1>{}); + constexpr index_t x1_size = reduce_on_sequence(hs_lengths_1, multiplies{}, number<1>{}); + + // Print distribution info (only from thread 0) + if(threadIdx.x == 0 && blockIdx.x == 0) { + printf("\n- Tile distribution created:\n"); + printf(" X dimensions: %d\n", distribution.get_num_of_dimension_x()); + printf(" Y dimensions: %d\n", distribution.get_num_of_dimension_y()); + printf(" P dimensions: %d\n", distribution.get_num_of_dimension_p()); + printf(" X lengths: [%d, %d]\n", x0_size, x1_size); + } + block_sync_lds(); + + // Create packed tensor view (contiguous row-major) using helper + auto global_view = make_naive_tensor_view_packed( + global_data, + make_tuple(global_shape_0, global_shape_1)); + + // Window configuration + auto window_lengths = make_tuple(x0_size, x1_size); + + // Get current thread's warp and thread indices + index_t warp_id = threadIdx.x / get_warp_size(); + index_t thread_id = threadIdx.x % get_warp_size(); + + // Window origin - small offset from origin + auto window_origin = make_tuple(1, 3); // Small offset from origin + + // Create tile window + auto tile_window = make_tile_window( + global_view, + window_lengths, + {1, 3}, // Window origin as initializer list + distribution + ); + + // Load distributed tensor + auto distributed_tensor = tile_window.load(); + + // Collect values by sweeping through the distributed tensor + constexpr index_t max_elements = x0_size*x1_size; + float collected_values[max_elements]; + index_t value_count = 0; + + // Sweep through the distributed tensor and collect values using sweep_tile API + sweep_tile(distributed_tensor, [&](auto idx) { + if(value_count(threadIdx.x) == sel) { + printf("Partition index: (warp=%d, thread=%d)\n", static_cast(warp_id), static_cast(thread_id)); + printf("Collected values: "); + for(index_t i = 0; i < value_count; i++) { + printf("%.0f", collected_values[i]); + if(i < 
value_count - 1) printf(", "); + } + printf("\n\n"); + } + block_sync_lds(); + } + } + }; + } + + int main() + { + // Host-side allocation & initialization of pattern data + // Reproduce the compile-time sizes used in the kernel: hs_lengths = [2,2] => x sizes=4; global = 4+5 = 9 + constexpr ck_tile::index_t global_shape_0 = 9; // x0_size(4) + 5 + constexpr ck_tile::index_t global_shape_1 = 9; // x1_size(4) + 5 + constexpr ck_tile::index_t total_elems = global_shape_0 * global_shape_1; // 81 + + std::vector h_global_data(total_elems); + for(ck_tile::index_t i = 0; i < global_shape_0; ++i) { + for(ck_tile::index_t j = 0; j < global_shape_1; ++j) { + h_global_data[i * global_shape_1 + j] = static_cast(i * 100 + j); + } + } + + ck_tile::DeviceMem d_global_data(sizeof(float) * total_elems); + d_global_data.ToDevice(h_global_data.data()); + + std::cout << "\nGlobal data (host print, to be used by device) shape=(" + << static_cast(global_shape_0) << "," << static_cast(global_shape_1) << ")\n\n"; + for(ck_tile::index_t i = 0; i < global_shape_0; ++i) { + for(ck_tile::index_t j = 0; j < global_shape_1; ++j) { + std::cout << h_global_data[i * global_shape_1 + j]; + if(j + 1 < global_shape_1) std::cout << "\t"; + } + std::cout << '\n'; + } + std::cout << '\n'; + + constexpr ck_tile::index_t kBlockSize = 128; + constexpr ck_tile::index_t kBlockPerCu = 1; + constexpr ck_tile::index_t kGridSize = 1; + + using Kernel = ck_tile::TileDistributionExample; + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(d_global_data.GetDeviceBuffer()), + global_shape_0, + global_shape_1)); + + std::cout << "Kernel execution completed. Average time: " << ave_time << " ms" << std::endl; + + return 0; + } + +Hierarchical Decomposition +-------------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. 
mermaid:: + + graph TB + subgraph "Level 1: Block Distribution" + B["Thread Block
256 threads"] + BT1["Block Tile 1
64×64"] + BT2["Block Tile 2
64×64"] + end + + subgraph "Level 2: Warp Distribution" + W["Warp
32 threads"] + WT1["Warp Tile 1
16×16"] + WT2["Warp Tile 2
16×16"] + end + + subgraph "Level 3: Thread Distribution" + T["Thread"] + TT["Thread Tile
2×2"] + end + + B --> BT1 + BT1 --> W + W --> WT1 + WT1 --> T + T --> TT + + style B fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style W fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style T fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + + + + +.. image:: diagrams/tile_distribution_4.svg + :alt: Diagram + :align: center + +Advanced Example: Matrix Multiplication Distribution +---------------------------------------------------- + +.. code-block:: cpp + + // Real GEMM kernel pattern using TileDistribution + template + __global__ void gemm_kernel( + const AType* __restrict__ a_ptr, + const BType* __restrict__ b_ptr, + CType* __restrict__ c_ptr, + index_t M, index_t N, index_t K) + { + // Define the tile distribution encoding at compile time + using Encoding = tile_distribution_encoding< + sequence<>, // R: no replication + tuple, // H for M dimension + sequence<4, 2, 8, 4>>, // H for N dimension + tuple, sequence<1, 2>>, // P to RH major + tuple, sequence<2, 2>>, // P to RH minor + sequence<1, 1, 2, 2>, // Y to RH major + sequence<0, 3, 0, 3> // Y to RH minor + >; + + // Create the distribution + constexpr auto distribution = make_static_tile_distribution(Encoding{}); + + // Create tensor views + auto a_view = make_tensor_view( + a_ptr, + make_naive_tensor_descriptor_packed(make_tuple(M, K))); + + // Create tile window for this thread block + auto a_window = make_tile_window( + a_view, + make_tuple(number<256>{}, number<64>{}), // window size + {blockIdx.x * 256, 0}, // origin + distribution); + + // Load data to distributed tensor (registers) + auto a_reg = make_static_distributed_tensor(distribution); + + a_window.load(a_reg); + + // Computation happens in registers + // Results written back through another window + } + +Work Distribution Pattern +------------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. 
mermaid:: + + flowchart TB + subgraph "Matrix C (128×128)" + C["16,384 elements"] + end + + subgraph "Thread Grid (32×32)" + TG["1,024 threads"] + end + + subgraph "Per Thread" + PT["4×4 tile
16 elements"] + end + + subgraph "Memory Access" + MA["Coalesced reads
Efficient writes
No conflicts"] + end + + C --> TG + TG --> PT + PT --> MA + + style C fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style TG fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style PT fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style MA fill:#d1fae5,stroke:#10b981,stroke-width:2px + + + + + + +.. image:: diagrams/tile_distribution_5.svg + :alt: Diagram + :align: center + +Memory Access Patterns +---------------------- + +One of the key benefits of TileDistribution is generating optimal memory access patterns. The encoding parameters control how threads access memory: + +- **H-dimensions**: Define hierarchical decomposition (Repeat, WarpPerBlock, ThreadPerWarp, Vector) +- **P-to-RH mappings**: Control how thread IDs map to the hierarchy +- **Y-to-RH mappings**: Define the access pattern within each thread's tile + +Transformation Pipeline +----------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Input" + TID["Thread ID
(0-1023)"] + end + + subgraph "Stage 1" + P["P-coordinates
(warp, lane)"] + end + + subgraph "Stage 2" + Y["Y-coordinates
(tile position)"] + end + + subgraph "Stage 3" + X["X-coordinates
(tensor indices)"] + end + + subgraph "Output" + ADDR["Memory addresses
Register indices"] + end + + TID --> P + P --> Y + Y --> X + X --> ADDR + + style TID fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + style ADDR fill:#d1fae5,stroke:#10b981,stroke-width:2px + + + + + +.. image:: diagrams/tile_distribution_6.svg + :alt: Diagram + :align: center + +Performance Comparison +---------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Manual Implementation" + M1["Calculate indices manually"] + M2["Handle boundary conditions"] + M3["Ensure coalescing"] + M4["Manage bank conflicts"] + M5["~200 lines of code"] + end + + subgraph "With TileDistribution" + T1["make_tile_distribution()"] + T2["Automatic optimization"] + T3["~10 lines of code"] + end + + subgraph "Performance" + P1["Same performance"] + P2["Fewer bugs"] + P3["Portable across GPUs"] + end + + M1 --> M5 + T1 --> T3 + + M5 --> P1 + T3 --> P1 + P1 --> P2 + P2 --> P3 + + style M5 fill:#fee2e2,stroke:#ef4444,stroke-width:2px + style T3 fill:#d1fae5,stroke:#10b981,stroke-width:2px + style P3 fill:#fef3c7,stroke:#f59e0b,stroke-width:2px + + + + + + +.. image:: diagrams/tile_distribution_7.svg + :alt: Diagram + :align: center + +Summary +------- + +The automatic work distribution capabilities of TileDistribution eliminate one of the most error-prone aspects of GPU programming. TileDistribution's mathematical framework ensures that every thread knows which data elements it should process and automatically handles complex index arithmetic. + +Memory access pattern optimization is a performance benefit of the TileDistribution approach. GPUs achieve their computational throughput only when memory accesses follow specific patterns that enable hardware optimizations such as coalescing and broadcast. 
TileDistribution automatically generates these patterns, ensuring that threads within a warp access contiguous memory locations, bank conflicts in shared memory are reduced, and the memory subsystem operates efficiently. This optimization happens transparently, without manual memory pattern analysis. + +By encoding the natural hierarchy of threads, warps, and blocks directly into the distribution strategy, the framework ensures that each level of the hierarchy operates optimally. This hierarchical approach enables tiling strategies that would be impractical to implement manually, such as multi-level tiling that simultaneously optimizes for L1 cache, L2 cache, and register file usage. + +The zero-overhead nature of TileDistribution, achieved through the use of C++ template metaprogramming and compile-time computation, ensures that the abstraction's benefits come without runtime cost. Every aspect of the distribution strategy is resolved at compile time, resulting in machine code that is comparable to hand-written implementations. The compiler's ability to see through the abstraction enables optimizations that aren't typically available to runtime-based approaches. + +The same source code can execute efficiently on GPUs with different warp sizes, different numbers of registers per thread, or different shared memory capacities. This portability includes performance portability, with the framework adapting its strategies to match the characteristics of the target architecture. + +TileDistribution provides a solid foundation for the CK ecosystem. This abstraction provides a programming model that insulates developers from the complexity of the underlying hardware while still enabling them to take full advantage of its capabilities. + +Next Steps +---------- + +See :ref:`ck_tile_terminology` for a glossary of key concepts and terminology used in CK Tile.
diff --git a/docs/conceptual/ck_tile/tile_window.rst b/docs/conceptual/ck_tile/tile_window.rst new file mode 100644 index 0000000000..87d2f39b01 --- /dev/null +++ b/docs/conceptual/ck_tile/tile_window.rst @@ -0,0 +1,701 @@ +.. _ck_tile_tile_window: + +Tile Window - Data Access Gateway +================================= + +Overview +-------- + +While :ref:`TileDistribution <ck_tile_tile_distribution>` determines the mapping between threads and tensor coordinates, TileWindow provides the mechanism for loading and storing data with efficient memory access patterns. This abstraction encapsulates coalesced memory accesses, vectorization, and boundary handling into a single interface. + +TileWindow implements a distribution-aware windowing mechanism that views a subset of a larger tensor through the lens of a tile distribution. The window automatically generates memory access patterns suited to the underlying hardware. The system combines knowledge of the :ref:`tensor's layout `, the distribution pattern, and the :ref:`GPU's memory subsystem ` characteristics to generate optimized load and store operations. + +TileWindow Architecture +----------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Components" + TV["TensorView
Data source"] + TD["TileDistribution
Thread mapping"] + TW["TileWindow
Access gateway"] + LT["LoadStoreTraits
Access optimizer"] + DT["DistributedTensor
Register storage"] + end + + subgraph "Operations" + Load["Load
Global → Registers"] + Compute["Compute
In registers"] + Store["Store
Registers → Global"] + end + + subgraph "Optimizations" + Coal["Coalescing
Adjacent access"] + Vec["Vectorization
Multi-element ops"] + Bank["Bank conflict
avoidance"] + SFC["Space-filling
curve traversal"] + end + + TV --> TW + TD --> TW + TW --> LT + LT --> DT + + TW --> Load + Load --> Compute + Compute --> Store + + Load --> Coal + Load --> Vec + Load --> SFC + Store --> Bank + + style TW fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style LT fill:#fff3e0,stroke:#f57c00,stroke-width:2px + style DT fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + +.. image:: diagrams/tile_window_1.svg + :alt: Diagram + :align: center + +What is a TileWindow? +--------------------- + +The challenge in GPU programming lies in the gap between logical tensor operations and the physical realities of memory access. While :ref:`TileDistribution ` solves the problem of work assignment by mapping threads to :ref:`tensor coordinates `, it does not address how threads access the data at those coordinates. TileWindow serves as the critical bridge between logical work assignment and physical memory operations. + +TileWindow implements a distribution-aware windowing mechanism that transforms abstract coordinate mappings into concrete memory access patterns. The abstraction takes into account the data elements each thread needs and also how to access them in a way that maximizes memory bandwidth utilization. This involves optimized techniques such as memory coalescing, where adjacent threads access adjacent memory locations, and vectorization, where multiple elements are loaded or stored in a single transaction. + +**C++ Implementation Overview:** + +.. 
code-block:: cpp + + // From ck_tile/core/tensor/tile_window.hpp + #include + #include + #include + + template + struct tile_window_with_static_distribution + { + using TensorView = remove_cvref_t; + using Distribution = remove_cvref_t; + using DataType = typename TensorView::DataType; + + // Core components that define the window + TensorView tensor_view_; // View into the underlying tensor + Distribution distribution_; // How to distribute data across threads + array origin_; + + // Window-specific information + static constexpr auto window_lengths = WindowLengths{}; + static constexpr index_t num_of_dimension = TensorView::get_num_of_dimension(); + + // Constructor + CK_TILE_HOST_DEVICE constexpr tile_window_with_static_distribution( + const TensorView& tensor_view, + const WindowLengths& /*window_lengths*/, + const array& origin, + const Distribution& distribution) + : tensor_view_{tensor_view}, + distribution_{distribution}, + origin_{origin} + {} + + // Load operation with automatic coalescing + template + CK_TILE_DEVICE void load(DistributedTensor& dst_tensor) const + { + // Sophisticated load implementation that: + // 1. Calculates optimal access pattern + // 2. Handles vectorization automatically + // 3. Ensures coalesced memory access + // 4. Manages boundary conditions + } + }; + +LoadStoreTraits - The Access Pattern Engine +------------------------------------------- + +Behind every TileWindow operation lies :ref:`LoadStoreTraits `, a compile-time analysis engine that determines an optimized way to access memory. This component bridges the gap between the logical distribution pattern and the physical memory subsystem, analyzing the distribution to find opportunities for vectorization and coalescing. 
+ +LoadStoreTraits performs several analyses: + +- **Vector dimension identification**: Finds which Y dimension has stride 1 for optimal vectorization +- **Access pattern calculation**: Determines the number and order of memory operations +- **Space-filling curve construction**: Creates an optimal traversal order for cache efficiency + +**C++ LoadStoreTraits Analysis:** + +.. code-block:: cpp + + // LoadStoreTraits analyzes the distribution pattern + template + struct load_store_traits + { + static constexpr index_t ndim_y = Distribution::ndim_y; + + // Analyze which Y dimension has stride 1 (best for vectorization) + static constexpr index_t vector_dim_y = []() { + // Complex compile-time analysis to find optimal dimension + return find_vector_dimension(); + }(); + + // Calculate vectorization potential + static constexpr index_t scalar_per_vector = []() { + // Determine how many elements can be loaded in one instruction + return calculate_vector_size(); + }(); + + // Space-filling curve for optimal traversal + using sfc_type = space_filling_curve; + static constexpr sfc_type sfc_ys = make_space_filling_curve(); + + // Get Y indices for a given access + CK_TILE_DEVICE constexpr auto get_y_indices(index_t i_access) const + { + return sfc_ys.get_index(i_access); + } + }; + +Space-Filling Curves for Memory Access +-------------------------------------- + +TileWindow uses :ref:`space-filling curves ` to determine the order in which memory is accessed. Space-filling curves provide cache-friendly traversal patterns that help maximize hardware utilization. The "snake" pattern minimizes the distance between consecutive accesses, keeping data in cache longer. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. 
mermaid:: + + graph LR + subgraph "Linear Access Pattern" + L1["0,1,2,3"] + L2["4,5,6,7"] + L3["8,9,10,11"] + L4["12,13,14,15"] + end + + subgraph "Snake Access Pattern" + S1["0,1,2,3"] + S2["7,6,5,4"] + S3["8,9,10,11"] + S4["15,14,13,12"] + end + + L1 --> L2 + L2 --> L3 + L3 --> L4 + + S1 --> S2 + S2 --> S3 + S3 --> S4 + + style S1 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style S2 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + +.. image:: diagrams/tile_window_2.svg + :alt: Diagram + :align: center + +**C++ Space-Filling Curve Implementation:** + +.. code-block:: cpp + + // Space-filling curve for optimal memory traversal + template + struct space_filling_curve + { + array tensor_lengths; + array dim_access_order; + bool snake_curved; + + // Get coordinates for the i-th access + CK_TILE_DEVICE constexpr auto get_index(index_t i_access) const + { + array indices; + + // Snake pattern logic for cache-friendly access + if (snake_curved) { + // Implement snake curve traversal + // Minimizes distance between consecutive accesses + } + + return indices; + } + }; + +TileWindow Data Flow +-------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + flowchart LR + subgraph "Step 1: Create Window" + T["Tensor
[256, 256]"] + O["Origin
(64, 64)"] + W["Window Size
[32, 32]"] + end + + subgraph "Step 2: Apply Distribution" + TD["TileDistribution
Thread mapping"] + TW["TileWindow
Created"] + end + + subgraph "Step 3: Load Data" + GM["Global Memory
Window region"] + REG["Registers
Distributed tensor"] + end + + T --> TW + O --> TW + W --> TW + TD --> TW + + TW --> GM + GM -->|"load()"| REG + + style TW fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style REG fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + +.. image:: diagrams/tile_window_3.svg + :alt: Diagram + :align: center + +Creating and Using TileWindow +----------------------------- + +.. code-block:: cpp + + using namespace ck_tile; + + // Create a tensor view for input data (see :ref:`ck_tile_tensor_views`) + auto tensor_view = make_naive_tensor_view( + data_ptr, + make_tuple(256, 256), // Shape + make_tuple(256, 1) // Strides + ); + + // Define window parameters + constexpr auto window_size = make_tuple(32, 32); + auto window_origin = make_tuple(64, 64); + + // Create distribution for the window + auto distribution = make_static_tile_distribution< + tile_distribution_encoding< + sequence<>, // No replication + tuple, sequence<4, 2>>, // 8x8 threads + tuple, sequence<1>>, // Thread mapping + tuple, sequence<0>>, // Minor indices + sequence<1, 1>, // Y-space: 2x2 per thread + sequence<1, 1> // Y-space minor + > + >{}; + + // Create the tile window + auto window = make_tile_window( + tensor_view, + window_size, + window_origin, + distribution + ); + + // Load data into distributed tensor (see :ref:`ck_tile_static_distributed_tensor`) + auto distributed_data = make_static_distributed_tensor(distribution); + window.load(distributed_data); + +The Load Operation Deep Dive +---------------------------- + +Calls to ``window.load()`` trigger the following sequence of operations: + +1. **Distributed tensor creation**: Automatically creates a :ref:`distributed tensor ` sized for the distribution +2. **Coordinate calculation**: Uses precomputed coordinates for efficiency +3. **Vectorized access**: Groups elements for vector loads based on :ref:`LoadStoreTraits ` analysis +4. **Memory coalescing**: Ensures adjacent threads access adjacent memory +5. 
**Boundary handling**: Manages edge cases automatically + +**C++ Load Implementation Details:** + +.. code-block:: cpp + + template + CK_TILE_DEVICE void load(DistributedTensor& dst_tensor) const + { + // Get LoadStoreTraits for optimal access pattern + using Traits = load_store_traits; + + // Iterate through all accesses determined by space-filling curve + static_for<0, Traits::num_access, 1>{}([&](auto i_access) { + // Get Y-space indices for this access + const auto y_indices = Traits::get_y_indices(i_access); + + // Calculate global coordinates + const auto x_indices = distribution_.calculate_x_from_y(y_indices); + const auto global_indices = add_arrays(origin_, x_indices); + + // Perform vectorized load if possible + if constexpr (Traits::scalar_per_vector > 1) { + // Vector load path + using VectorType = vector_type_t; + const auto vector_data = tensor_view_.template get_vectorized_elements( + global_indices, Traits::vector_dim_y); + dst_tensor.template set_vectorized_elements(y_indices, vector_data); + } else { + // Scalar load path + const auto scalar_data = tensor_view_.get_element(global_indices); + dst_tensor.set_element(y_indices, scalar_data); + } + }); + } + +Load Operation Architecture +--------------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Load Analysis" + Analyze["Analyze access pattern
Detect coalescing opportunities"] + end + + subgraph "Vectorization" + V1["Scalar: 4 loads"] + V2["Vector2: 2 loads"] + V4["Vector4: 1 load"] + end + + subgraph "Memory Transaction" + Coal["Coalesced access
32 threads → 1 transaction"] + NonCoal["Non-coalesced
32 threads → 32 transactions"] + end + + subgraph "Result" + Reg["Thread registers
Local data"] + end + + Analyze --> V1 + Analyze --> V2 + Analyze --> V4 + + V4 --> Coal + V1 --> NonCoal + + Coal --> Reg + NonCoal --> Reg + + style V4 fill:#d1fae5,stroke:#10b981,stroke-width:2px + style Coal fill:#d1fae5,stroke:#10b981,stroke-width:2px + style NonCoal fill:#fee2e2,stroke:#ef4444,stroke-width:2px + + + +.. image:: diagrams/tile_window_4.svg + :alt: Diagram + :align: center + +Memory Access Patterns +---------------------- + +One of TileWindow's key features is generating optimal memory access patterns. The system analyzes the distribution to ensure: + +- **Coalescing**: Adjacent threads access adjacent memory locations +- **Vectorization**: Multiple elements loaded in single instructions +- **Bank conflict avoidance**: Shared memory accesses avoid :ref:`conflicts ` +- **Cache optimization**: Access patterns maximize cache reuse + +**C++ Memory Pattern Analysis:** + +.. code-block:: cpp + + // Analyze memory access pattern for a distribution + template + struct memory_access_analyzer + { + static constexpr bool is_coalesced() + { + // Check if threads in a warp access consecutive memory + return Distribution::check_coalescing_pattern(); + } + + static constexpr index_t vector_size() + { + // Determine optimal vector size (1, 2, 4, 8) + return Distribution::calculate_vector_size(); + } + + static constexpr bool has_bank_conflicts() + { + // Analyze shared memory access pattern + return Distribution::detect_bank_conflicts(); + } + }; + +Window Movement and Updates +--------------------------- + +TileWindow supports efficient window movement for sliding window algorithms. The precomputed coordinate system makes updates more efficient: + +.. 
code-block:: cpp + + // Sliding window pattern + for (index_t row = 0; row < tensor_height; row += stride) { + for (index_t col = 0; col < tensor_width; col += stride) { + // Update window position - O(1) operation + window.set_window_origin(make_tuple(row, col)); + + // Load from new position + window.load(distributed_data); + + // Process data + process_tile(distributed_data); + + // Store results if needed + output_window.store(distributed_data); + } + } + +Store Operations with Vectorization +----------------------------------- + +Store operations use the same compile-time analysis as loads. The :ref:`LoadStoreTraits ` helps make stores as efficient as loads, with similar vectorization and coalescing benefits: + +.. code-block:: cpp + + template + CK_TILE_DEVICE void store(const DistributedTensor& src_tensor) const + { + using Traits = load_store_traits; + + // Same optimized pattern as load, but in reverse + static_for<0, Traits::num_access, 1>{}([&](auto i_access) { + const auto y_indices = Traits::get_y_indices(i_access); + const auto x_indices = distribution_.calculate_x_from_y(y_indices); + const auto global_indices = add_arrays(origin_, x_indices); + + if constexpr (Traits::scalar_per_vector > 1) { + // Vectorized store + const auto vector_data = src_tensor.template get_vectorized_elements( + y_indices, Traits::vector_dim_y); + tensor_view_.template set_vectorized_elements( + global_indices, vector_data, Traits::vector_dim_y); + } else { + // Scalar store + const auto scalar_data = src_tensor.get_element(y_indices); + tensor_view_.set_element(global_indices, scalar_data); + } + }); + } + +Complete Load-Compute-Store Pipeline +------------------------------------ + +.. 
code-block:: cpp + + template + __global__ void gemm_kernel_with_windows( + const AType* __restrict__ a_ptr, + const BType* __restrict__ b_ptr, + CType* __restrict__ c_ptr, + index_t M, index_t N, index_t K) + { + // Create tensor views + auto a_tensor = make_naive_tensor_view( + a_ptr, make_tuple(M, K), make_tuple(K, 1)); + auto b_tensor = make_naive_tensor_view( + b_ptr, make_tuple(K, N), make_tuple(N, 1)); + auto c_tensor = make_naive_tensor_view( + c_ptr, make_tuple(M, N), make_tuple(N, 1)); + + // Define tile sizes + constexpr index_t tile_m = 32; + constexpr index_t tile_n = 32; + constexpr index_t tile_k = 8; + + // Create distributions + auto a_dist = make_static_tile_distribution<...>(); + auto b_dist = make_static_tile_distribution<...>(); + auto c_dist = make_static_tile_distribution<...>(); + + // Calculate tile position + const index_t block_m = blockIdx.y * tile_m; + const index_t block_n = blockIdx.x * tile_n; + + // Create tile windows + auto a_window = make_tile_window( + a_tensor, + make_tuple(tile_m, tile_k), + make_tuple(block_m, 0), + a_dist); + + auto b_window = make_tile_window( + b_tensor, + make_tuple(tile_k, tile_n), + make_tuple(0, block_n), + b_dist); + + auto c_window = make_tile_window( + c_tensor, + make_tuple(tile_m, tile_n), + make_tuple(block_m, block_n), + c_dist); + + // Create distributed tensors for register storage + // See :ref:`ck_tile_static_distributed_tensor` for details + auto a_reg = make_static_distributed_tensor(a_dist); + auto b_reg = make_static_distributed_tensor(b_dist); + auto c_reg = make_static_distributed_tensor(c_dist); + + // Initialize accumulator + c_reg.clear(); + + // Main GEMM loop + for(index_t k = 0; k < K; k += tile_k) { + // Update window positions + a_window.set_window_origin(make_tuple(block_m, k)); + b_window.set_window_origin(make_tuple(k, block_n)); + + // Load tiles - LoadStoreTraits ensures optimal pattern + a_window.load(a_reg); + b_window.load(b_reg); + + // Compute + gemm(a_reg, b_reg, 
c_reg); + } + + // Store result - using same optimized pattern + c_window.store(c_reg); + } + +Performance Characteristics +--------------------------- + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph LR + subgraph "Memory Access Optimization" + V["Vectorization
4x fewer transactions"] + C["Coalescing
32x bandwidth efficiency"] + P["Precomputation
Zero overhead addressing"] + S["Space-filling
Optimal cache usage"] + end + + subgraph "Hardware Utilization" + BW["Memory Bandwidth
Near 100% utilization"] + L["Latency Hiding
Overlapped operations"] + R["Register Reuse
Minimal spills"] + end + + V --> BW + C --> BW + P --> L + S --> R + + style V fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + style C fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + style BW fill:#d1fae5,stroke:#10b981,stroke-width:3px + + + +.. image:: diagrams/tile_window_5.svg + :alt: Diagram + :align: center +Best Practices +-------------- + +Window Size Selection +~~~~~~~~~~~~~~~~~~~~~ + +Choose window sizes that balance register usage with data reuse: + +.. code-block:: cpp + + // Optimal window size calculation + template + constexpr auto calculate_optimal_window_size() + { + // Consider register constraints + constexpr index_t elements_per_thread = RegistersPerThread / sizeof(DataType); + + // Common tile sizes that work well + constexpr array common_sizes = {8, 16, 32, 64, 128}; + + // Find largest size that fits in registers + for (auto size : common_sizes) { + if (size * size <= elements_per_thread) { + return size; + } + } + return 8; // Minimum reasonable size + } + +Access Pattern Optimization +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Design distributions for optimal memory access: + +.. 
code-block:: cpp + + // Create distribution optimized for coalescing + // This example shows a 32x32 tile distributed across threads + using OptimalDistribution = tile_distribution_encoding< + sequence<>, // RsLengths: No replication + tuple, sequence<4, 8>>, // HsLengthss: Hierarchical decomposition + tuple, sequence<1, 2>>, // Ps2RHssMajor: P to RH major mapping + tuple, sequence<1, 0>>, // Ps2RHssMinor: P to RH minor mapping + sequence<1, 1, 2, 2>, // Ys2RHsMajor: Y to RH major mapping + sequence<0, 1, 0, 1> // Ys2RHsMinor: Y to RH minor mapping + >; + +Summary +------- + +TileWindow provides: + +- **Automatic optimization**: Generates optimal memory access patterns through :ref:`LoadStoreTraits <ck_tile_load_store_traits>` +- **Distribution awareness**: Works seamlessly with :ref:`TileDistribution <ck_tile_tile_distribution>` +- **Space-filling curves**: Optimizes traversal order for cache efficiency (see :ref:`ck_tile_space_filling_curve`) +- **Vectorization**: Automatic multi-element operations +- **Precomputation**: Zero-overhead :ref:`coordinate transformations ` +- **Flexible windowing**: Supports various access patterns and window configurations +- **Safety**: Automatic boundary handling + +Key benefits: + +1. **Performance**: Improves memory bandwidth utilization through coalescing and vectorization +2. **Productivity**: Reduces reliance on manual memory management code +3. **Correctness**: Automatic boundary checking and handling +4. **Composability**: Integrates with other CK abstractions +5. **Intelligence**: LoadStoreTraits analyzes and optimizes access + +The TileWindow abstraction helps build high-performance GPU kernels, providing an interface for complex memory access patterns while helping maintain performance gains. The compile-time analysis performed by LoadStoreTraits ensures that memory operations are as efficient as possible, while the space-filling curve traversal maximizes cache utilization.
+ +Next Steps +---------- + + +- :ref:`ck_tile_load_store_traits` - Deep dive into access pattern optimization +- :ref:`ck_tile_space_filling_curve` - Advanced traversal patterns +- :ref:`ck_tile_static_distributed_tensor` - Register-based tensor storage +- :ref:`ck_tile_lds_index_swapping` - Advanced shared memory optimization +- :ref:`ck_tile_sweep_tile` - Efficient tile-based algorithms diff --git a/docs/conceptual/ck_tile/transforms.rst b/docs/conceptual/ck_tile/transforms.rst new file mode 100644 index 0000000000..63b830563e --- /dev/null +++ b/docs/conceptual/ck_tile/transforms.rst @@ -0,0 +1,769 @@ +.. _ck_tile_transforms: + +Individual Transform Operations +=============================== + +The transformation engine is built from individual transform types that each handle specific coordinate conversions. + +What Are Transforms? +-------------------- + +Transform operations convert coordinates between different dimensional spaces. Each transform operates between two :ref:`coordinate spaces `: + +- **Lower Dimension Space**: The source coordinate system +- **Upper Dimension Space**: The target coordinate system + +Transform Direction +~~~~~~~~~~~~~~~~~~~ + +Transforms work bidirectionally: + +- **Forward Transform**: Converts coordinates from the lower dimension to the upper dimension +- **Inverse Transform**: Converts coordinates back from the upper dimension to the lower dimension + +Zero-Copy Logical Operations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**Critical Understanding**: All transform operations happen in **logical coordinate space** only. This is a zero-copy system and there is **no data copying or movement** involved. + +- **Data Storage**: The actual tensor data remains stored in memory in linear fashion, exactly as specified by the original tensor shape and strides at creation time. See :ref:`ck_tile_buffer_views` for more information about raw memory access. 
+- **Logical Mapping**: Transforms create different logical views of the same underlying data and only change how access coordinates are interpreted. See :ref:`ck_tile_tensor_views` for more information about tensor views. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "Tensor Coordinate Transformation" + US["Lower Dimension Space
Source coordinate system"] + LS["Upper Dimension Space
Target coordinate system"] + + DATA["Linear Data in Memory
Layout determined by tensor
shape & strides"] + end + + US -->|"Forward Transform"| LS + LS -->|"Inverse Transform"| US + + DATA -.->|"Same data,
different views"| US + DATA -.->|"Same data,
different views"| LS + + style US fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style LS fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + + + +.. image:: diagrams/transforms_1.svg + :alt: Diagram + :align: center + +Index Calculation Operations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The transform system provides two operations for coordinate conversion: + +- **calculate_lower_index()**: Takes a coordinate from the **upper dimension space** and transforms it to get the corresponding index or coordinate in the **lower dimension space**. This calculates where to find the actual tensor element using the transformed coordinate system. + +- **calculate_upper_index()**: Takes a coordinate from the **lower dimension space** and transforms it back to get the corresponding coordinate in the **upper dimension space**. This performs the inverse transformation to recover the original coordinate representation. + +These operations enable bidirectional navigation between different coordinate representations of the same underlying tensor data. + +Transform System Architecture +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + + subgraph "Transform Types" + EMB["EmbedTransform
Linear → Multi-D Strided"] + UNM["MergeTransform
Multi-D → Linear"] + MRG["UnmergeTransform
Linear → Multi-D"] + REP["ReplicateTransform
0D → Multi-D Broadcast"] + OFF["OffsetTransform
Translation"] + PAS["PassThroughTransform
Identity"] + PAD["PadTransform
Boundaries"] + end + + subgraph "Operations" + FWD["Forward
calculate_lower_index()"] + BWD["Backward
calculate_upper_index()"] + UPD["Update
update_lower_index()"] + end + + EMB --> FWD + UNM --> FWD + MRG --> FWD + REP --> FWD + OFF --> FWD + PAS --> FWD + PAD --> FWD + + style FWD fill:#e8f5e9,stroke:#388e3c,stroke-width:2px + + + + + +.. image:: diagrams/transforms_2.svg + :alt: Diagram + :align: center + +MergeTransform +-------------- + +MergeTransform collapses multiple dimensions from the lower coordinate space into a single dimension in the upper coordinate space, effectively reducing the dimensionality of the tensor representation while preserving data relationships. This transform is fundamental to the :ref:`tile distribution system `. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "MergeTransform: Multi-D → Linear" + LS["Lower Coordinate Space
2D: [4, 5]
Coord: (2, 3)"] + US["Upper Coordinate Space
1D Linear
Index: 13"] + + DATA["Same Tensor Data
Layout: row-major
Size: 20 elements"] + end + + LS -->|"Forward Transform
2×5 + 3 = 13"| US + US -->|"Inverse Transform
13÷5=2, 13%5=3"| LS + + DATA -.->|"Multi-dimensional
view"| LS + DATA -.->|"Linear
view"| US + + style LS fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style US fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + + + +.. image:: diagrams/transforms_3.svg + :alt: Diagram + :align: center + + +**C++ Implementation:** + +.. code-block:: cpp + + using namespace ck_tile; + + // Create MergeTransform for 4x5 tensor (20 elements total) + auto transform = make_merge_transform(make_tuple(4, 5)); + + // Forward: Lower (2D) → Upper (1D) - Manual calculation + int row = 2, col = 3; + int linear_index = row * 5 + col; // Result: 13 + printf("2D coord [%d, %d] → Linear index %d\n", row, col, linear_index); + printf("Calculation: %d×5 + %d = %d\n", row, col, linear_index); + + // Inverse: Upper (1D) → Lower (2D) - Using transform + multi_index<1> upper_coord; + upper_coord[number<0>{}] = 13; + + multi_index<2> lower_coord; + transform.calculate_lower_index(lower_coord, upper_coord); + + printf("Linear index %d → 2D coord [%d, %d]\n", + static_cast(upper_coord[number<0>{}]), + static_cast(lower_coord[number<0>{}]), + static_cast(lower_coord[number<1>{}])); + printf("Calculation: 13 ÷ 5 = %d remainder %d\n", + static_cast(lower_coord[number<0>{}]), + static_cast(lower_coord[number<1>{}])); + +UnmergeTransform +---------------- + +UnmergeTransform expands coordinates from a single dimension in the lower coordinate space into multiple dimensions in the upper coordinate space, effectively increasing the dimensionality of the tensor representation while preserving all data relationships. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "UnmergeTransform: Linear → Multi-D" + LS["Lower Coordinate Space
1D Linear
Index: 14"] + US["Upper Coordinate Space
3D: [3, 4, 2]
Coord: (1, 3, 0)"] + + DATA["Same Tensor Data
Layout: row-major
Size: 24 elements"] + end + + LS -->|"Forward Transform
14 = 1×8 + 3×2 + 0"| US + US -->|"Inverse Transform
linearize back"| LS + + DATA -.->|"Linear
view"| LS + DATA -.->|"Multi-dimensional
view"| US + + style LS fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style US fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + + + +.. image:: diagrams/transforms_4.svg + :alt: Diagram + :align: center + +**C++ Implementation:** + +.. code-block:: cpp + + using namespace ck_tile; + + // Create UnmergeTransform for 3x4x2 tensor (24 elements total) + auto transform = make_unmerge_transform(make_tuple(3, 4, 2)); + + // Forward: Lower (1D) → Upper (3D) - Manual calculation + int linear_index = 14; + int dim0 = linear_index / (4 * 2); // 14 / 8 = 1 + int remainder = linear_index % (4 * 2); // 14 % 8 = 6 + int dim1 = remainder / 2; // 6 / 2 = 3 + int dim2 = remainder % 2; // 6 % 2 = 0 + + printf("Linear index %d → 3D coord [%d, %d, %d]\n", + linear_index, dim0, dim1, dim2); + printf("Calculation: 14 = %d×8 + %d×2 + %d\n", dim0, dim1, dim2); + + // Inverse: Upper (3D) → Lower (1D) - Using transform + multi_index<3> upper_coord; + upper_coord[number<0>{}] = 1; + upper_coord[number<1>{}] = 3; + upper_coord[number<2>{}] = 0; + + multi_index<1> lower_coord; + transform.calculate_lower_index(lower_coord, upper_coord); + + printf("3D coord [%d, %d, %d] → Linear index %d\n", + static_cast<int>(upper_coord[number<0>{}]), + static_cast<int>(upper_coord[number<1>{}]), + static_cast<int>(upper_coord[number<2>{}]), + static_cast<int>(lower_coord[number<0>{}])); + printf("Calculation: %d×8 + %d×2 + %d = %d\n", + static_cast<int>(upper_coord[number<0>{}]), + static_cast<int>(upper_coord[number<1>{}]), + static_cast<int>(upper_coord[number<2>{}]), + static_cast<int>(lower_coord[number<0>{}])); + +EmbedTransform +-------------- + +EmbedTransform expands linear indices from the lower coordinate space into multi-dimensional coordinates in the upper coordinate space using configurable strides, enabling flexible strided tensor layouts and sub-tensor views within larger buffers. + + +..
+ Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "EmbedTransform: Linear → Multi-D Strided" + LS["Lower Coordinate Space
1D Linear
Index: 14"] + US["Upper Coordinate Space
2D: [2, 3]
Coord: (1, 2)"] + + DATA["Linear Buffer in Memory"] + end + + LS -->|"Forward Transform
Strides: [12, 1]
14 ÷ 12 = 1, 14 % 12 = 2"| US + US -->|"Inverse Transform
1×12 + 2×1 = 14"| LS + + DATA -.->|"Linear
index view"| LS + DATA -.->|"Multi-dimensional
strided view"| US + + style LS fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style US fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + + + +.. image:: diagrams/transforms_5.svg + :alt: Diagram + :align: center + +**C++ Implementation:** + +.. code-block:: cpp + + using namespace ck_tile; + + // Create embed transform for 2x3 tensor with strides [12, 1] + // This is commonly used in :ref:`descriptors <ck_tile_descriptors>` + auto transform = make_embed_transform(make_tuple(2, 3), make_tuple(12, 1)); + + // Forward: Linear → 2D (Manual calculation) + int linear_idx = 14; + int row = linear_idx / 12; // 14 / 12 = 1 + int remainder = linear_idx % 12; // 14 % 12 = 2 + int col = remainder / 1; // 2 / 1 = 2 + printf("Linear index %d → 2D coord [%d, %d]\n", linear_idx, row, col); + + // Inverse: 2D → Linear (Using transform) + multi_index<2> upper_coord; + upper_coord[number<0>{}] = 1; + upper_coord[number<1>{}] = 2; + + multi_index<1> lower_coord; + transform.calculate_lower_index(lower_coord, upper_coord); + printf("2D coord [%d, %d] → Linear index %d\n", + static_cast<int>(upper_coord[number<0>{}]), + static_cast<int>(upper_coord[number<1>{}]), + static_cast<int>(lower_coord[number<0>{}])); + +ReplicateTransform +------------------ + +ReplicateTransform creates a higher-dimensional tensor by replicating (broadcasting) a lower-dimensional tensor. It's essentially a broadcasting operation that takes a tensor with fewer dimensions and logically replicates it across new dimensions without data duplication. An example is taking a scalar (0-dimensional) input and broadcasting it across multiple dimensions, enabling efficient broadcasting patterns where a single value appears at every position in a multi-dimensional coordinate space. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "ReplicateTransform: 0D → Multi-D Broadcasting" + LS["Lower Coordinate Space
0D: Scalar
Empty coordinate []"] + US["Upper Coordinate Space
2D: [3, 4]
All coords: (i, j)"] + + DATA["Single Scalar Value"] + end + + LS -->|"Forward Transform
[] → (i,j) for any i,j"| US + US -->|"Inverse Transform
(i,j) → [] for any i,j"| LS + + DATA -.->|"One scalar
value"| LS + DATA -.->|"Broadcasted view
at all positions"| US + + style LS fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style US fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + + + +.. image:: diagrams/transforms_6.svg + :alt: Diagram + :align: center + +**C++ Implementation:** + +.. code-block:: cpp + + using namespace ck_tile; + + // Create replicate transform for 3x4 broadcasting + auto transform = make_replicate_transform(make_tuple(3, 4)); + + // Inverse: Upper (2D) → Lower (0D) - Using transform + // Any 2D coordinate maps to empty scalar coordinate + multi_index<2> upper_coord; + upper_coord[number<0>{}] = 1; + upper_coord[number<1>{}] = 2; + + multi_index<0> lower_coord; // Empty coordinate (0 dimensions) + transform.calculate_lower_index(lower_coord, upper_coord); + printf("2D [%d, %d] → Empty scalar [] (always empty)\n", + static_cast<int>(upper_coord[number<0>{}]), + static_cast<int>(upper_coord[number<1>{}])); + + // Forward: Scalar → 2D (Conceptual - no coordinate calculation needed) + // Broadcasting: Single scalar value appears at ALL positions + printf("Broadcasting: Scalar value appears at every [i,j] where 0≤i<3, 0≤j<4\n"); + printf("Total positions: 3×4 = 12 positions, all contain same scalar value\n"); + + // Test multiple coordinates - all map to empty scalar + int test_coords[][2] = {{0, 0}, {1, 2}, {2, 3}}; + for(int i = 0; i < 3; i++) + { + multi_index<2> test_upper; + test_upper[number<0>{}] = test_coords[i][0]; + test_upper[number<1>{}] = test_coords[i][1]; + + multi_index<0> test_lower; + transform.calculate_lower_index(test_lower, test_upper); + printf("2D [%d, %d] → Empty scalar []\n", + test_coords[i][0], test_coords[i][1]); + } + +OffsetTransform +--------------- + +OffsetTransform shifts coordinates by a fixed offset, creating a translated view of the coordinate space.
It performs translation operations where each coordinate in the upper space is mapped to a coordinate in the lower space by adding a constant offset. + + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "OffsetTransform: 1D → 1D Translation" + LS["Lower Coordinate Space
1D: [0, 63]
Coord: index + offset"] + US["Upper Coordinate Space
1D: [0, 47]
Coord: index"] + + DATA["Linear Buffer in Memory"] + end + + LS -->|"Forward Transform
idx + 16 → idx"| US + US -->|"Inverse Transform
idx → idx + 16"| LS + + DATA -.->|"Lower
view"| LS + DATA -.->|"Upper
view"| US + + style LS fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style US fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + + + +.. image:: diagrams/transforms_7.svg + :alt: Diagram + :align: center + +**C++ Implementation:** + +.. code-block:: cpp + + using namespace ck_tile; + + // Create offset transform for coordinate translation + // CK Tile formula: lower = upper + offset + auto transform = make_offset_transform(48, 16); + + // Using Transform: Original → Translated (adds offset) + multi_index<1> upper_coord; + upper_coord[number<0>{}] = 5; // Original index 5 + + multi_index<1> lower_coord; + transform.calculate_lower_index(lower_coord, upper_coord); + printf("Original index %d → Translated index %d\n", + static_cast<int>(upper_coord[number<0>{}]), + static_cast<int>(lower_coord[number<0>{}])); + printf("Calculation: %d + 16 = %d\n", + static_cast<int>(upper_coord[number<0>{}]), + static_cast<int>(lower_coord[number<0>{}])); + + // Manual Reverse: Translated → Original (subtract offset) + int translated_idx = 21; + int original_idx = translated_idx - 16; + printf("Translated index %d → Original index %d\n", translated_idx, original_idx); + + // Test multiple coordinates + int test_indices[] = {0, 10, 20, 47}; + for(int i = 0; i < 4; i++) + { + multi_index<1> test_upper; + test_upper[number<0>{}] = test_indices[i]; + + multi_index<1> test_lower; + transform.calculate_lower_index(test_lower, test_upper); + printf("Original %d → Translated %d\n", + test_indices[i], static_cast<int>(test_lower[number<0>{}])); + } + +PassThroughTransform - Identity +------------------------------- + +No-op transform that passes coordinates unchanged. The PassThrough transform is the simplest coordinate transformation in CK Tile, implementing a perfect identity mapping where input coordinates are passed through unchanged to the output.
This transform is essential as a placeholder in transformation chains and for dimensions that require no modification. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "PassThroughTransform: 1D → 1D Identity" + LS["Lower Coordinate Space
1D: [0, 59]
Coord: index"] + US["Upper Coordinate Space
1D: [0, 59]
Coord: index"] + + DATA["Linear Buffer in Memory"] + end + + LS -.->|"Perfect Identity
idx → idx"| US + US -.->|"Perfect Identity
idx → idx"| LS + + DATA -->|"Same buffer
same view"| LS + DATA -->|"Same buffer
same view"| US + + style LS fill:#e8f5e8,stroke:#2e7d32,stroke-width:3px + style US fill:#e8f5e8,stroke:#2e7d32,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + + + +.. image:: diagrams/transforms_8.svg + :alt: Diagram + :align: center + +**C++ Implementation:** + +.. code-block:: cpp + + using namespace ck_tile; + + // Identity transform - no change + int length = 60; + + auto transform = make_pass_through_transform(length); + + printf("Length: %d\n", length); + + // Forward: Upper → Lower (identity) + multi_index<1> upper_coord; + upper_coord[number<0>{}] = 25; + + multi_index<1> lower_coord; + transform.calculate_lower_index(lower_coord, upper_coord); + + printf("\nForward: [%d] → [%d] (unchanged)\n", + static_cast<int>(upper_coord[number<0>{}]), + static_cast<int>(lower_coord[number<0>{}])); + + // Reverse: Lower → Upper (identity) + multi_index<1> lower_input; + lower_input[number<0>{}] = 42; + + multi_index<1> upper_result; + // Note: PassThrough is bidirectional identity, so we can use same function + transform.calculate_lower_index(upper_result, lower_input); + + printf("Reverse: [%d] → [%d] (unchanged)\n", + static_cast<int>(lower_input[number<0>{}]), + static_cast<int>(upper_result[number<0>{}])); + +PadTransform +------------ + +PadTransform adds padding to tensor dimensions, mapping coordinates from upper dimension space (with padding) to lower dimension space (original data). + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "PadTransform: 1D → 1D with Padding" + LS["Lower Coordinate Space
1D: [0, 2] (original data)"] + US["Upper Coordinate Space
1D: [0, 4] (with padding)"] + + DATA["Tensor Data in Memory"] + end + + LS -->|"Forward Transform
idx + left_pad"| US + US -->|"Inverse Transform
idx - left_pad"| LS + + DATA -.->|"Original view"| LS + DATA -.->|"Padded view"| US + + style LS fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style US fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + + + +.. image:: diagrams/transforms_9.svg + :alt: Diagram + :align: center + +**C++ Implementation:** + +.. code-block:: cpp + + using namespace ck_tile; + + // PadTransform for coordinate padding + int low_length = 3; // Original dimension length + int left_pad = 1; // Padding on left + int right_pad = 1; // Padding on right + + auto transform = make_pad_transform(low_length, left_pad, right_pad); + + printf("Low length: %d\n", low_length); + printf("Left pad: %d\n", left_pad); + printf("Right pad: %d\n", right_pad); + printf("Upper length: %d (total with padding)\n", low_length + left_pad + right_pad); + + // Test coordinate mapping + int test_coords[] = {0, 1, 2, 3, 4}; + for(int i = 0; i < 5; i++) + { + multi_index<1> upper; + upper[number<0>{}] = test_coords[i]; + + multi_index<1> lower; + transform.calculate_lower_index(lower, upper); + + int adjusted_idx = static_cast<int>(lower[number<0>{}]); + bool is_valid = (adjusted_idx >= 0 && adjusted_idx < low_length); + + printf("Upper %d → Lower %d (%s)\n", + test_coords[i], adjusted_idx, + is_valid ? "valid" : "padding"); + } + +Additional Transform Types +-------------------------- + +XorTransform +~~~~~~~~~~~~ + +XorTransform applies a 2D XOR mapping for specialized memory access patterns. It performs XOR operations on coordinates to create transformed memory layouts for specific algorithmic optimizations, particularly useful for avoiding :ref:`LDS bank conflicts `. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "XorTransform: 2D → 2D XOR Mapping" + LS["Lower Coordinate Space
2D: [4, 8]
XOR-transformed coords"] + US["Upper Coordinate Space
2D: [4, 8]
Normal coords"] + + DATA["Same Tensor Data"] + end + + LS -->|"Forward Transform
apply XOR reverse"| US + US -->|"Inverse Transform
apply XOR mapping"| LS + + DATA -.->|"XOR pattern
view"| LS + DATA -.->|"Normal
view"| US + + style LS fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style US fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + + + +.. image:: diagrams/transforms_10.svg + :alt: Diagram + :align: center + +SliceTransform +~~~~~~~~~~~~~~ + +SliceTransform extracts a sub-region from a tensor dimension. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "SliceTransform: 1D → 1D Sub-region" + LS["Lower Coordinate Space
1D: [0, 9] (original range)"] + US["Upper Coordinate Space
1D: [0, 4] (slice range)"] + + DATA["Tensor Data in Memory"] + end + + LS -->|"Forward Transform
idx - slice_begin"| US + US -->|"Inverse Transform
idx + slice_begin"| LS + + DATA -.->|"Full tensor
view"| LS + DATA -.->|"Sub-region
view"| US + + style LS fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style US fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + + + +.. image:: diagrams/transforms_11.svg + :alt: Diagram + :align: center + +ModuloTransform +~~~~~~~~~~~~~~~ + +ModuloTransform applies cyclic wrapping to coordinates using modulo operations. + +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + + .. mermaid:: + + graph TB + subgraph "ModuloTransform: 1D → 1D Cyclic" + LS["Lower Coordinate Space
1D: [0, 3] (modulus range)"] + US["Upper Coordinate Space
1D: [0, 15] (full range)"] + + DATA["Tensor Data in Memory"] + end + + LS -->|"Forward Transform
idx * cycle_count"| US + US -->|"Inverse Transform
idx % modulus"| LS + + DATA -.->|" "| LS + DATA -.->|" "| US + + style LS fill:#e3f2fd,stroke:#1976d2,stroke-width:3px + style US fill:#fff3e0,stroke:#f57c00,stroke-width:3px + style DATA fill:#f0f9ff,stroke:#0284c7,stroke-width:2px,stroke-dasharray: 5 5 + + + +.. image:: diagrams/transforms_12.svg + :alt: Diagram + :align: center + +Summary +------- + +Individual transforms provide: + +- **Modularity**: Each transform does one thing +- **Composability**: Chain transforms for complex mappings (see :ref:`ck_tile_adaptors`) +- **Efficiency**: Compile-time optimization in C++ +- **Flexibility**: Handle any coordinate conversion need + +These transforms enable you to: + +1. Create custom tensor views +2. Implement efficient data access patterns +3. Handle padding and boundaries correctly +4. Optimize memory layouts for :ref:`GPU access ` + +The C++ implementations in Composable Kernel provide: + +- Zero-overhead abstractions through templates +- Compile-time composition and optimization +- Support for complex coordinate transformations +- Integration with GPU kernel generation +- Foundation for :ref:`tile windows ` and :ref:`load/store traits ` + +Next Steps +---------- + +- :ref:`ck_tile_adaptors` - How to chain transforms together +- :ref:`ck_tile_descriptors` - Complete tensor descriptions with transforms +- :ref:`ck_tile_tile_window` - Using transforms for efficient data loading +- :ref:`ck_tile_thread_mapping` - How transforms enable thread-to-data mapping +- :ref:`ck_tile_gemm_optimization` - Practical application in GEMM kernels diff --git a/docs/conceptual/ck_tile/update_diagrams.py b/docs/conceptual/ck_tile/update_diagrams.py new file mode 100644 index 0000000000..2fbe2ef5a9 --- /dev/null +++ b/docs/conceptual/ck_tile/update_diagrams.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 +""" +Helper script to update SVG diagrams from commented mermaid sources in RST files. 
+ +This script scans RST files for commented mermaid blocks (created by convert_mermaid_to_svg.py) +and regenerates the corresponding SVG files when the source has been modified. + +Usage: + python update_diagrams.py # Update all diagrams + python update_diagrams.py <file.rst> # Update diagrams in a specific file +""" + +import os +import re +import subprocess +import sys +import tempfile +from pathlib import Path + +# Configuration +DOCS_DIR = Path(__file__).parent +DIAGRAMS_DIR = DOCS_DIR / "diagrams" + +# Pattern to find commented mermaid blocks followed by image references +COMMENTED_MERMAID_PATTERN = re.compile( + r"\.\.\s*\n" # Comment start + r"(?: .*\n|\s*\n)*?" # Comment description lines (may have blank lines) + r"( \.\. mermaid::\s*\n" # Commented mermaid directive + r"(?: \n| .*\n|\s*\n)*?)" # Mermaid content (including blank lines) + r"\.\. image:: diagrams/([^\s]+)", # Image reference + re.MULTILINE, +) + + +def extract_mermaid_from_comment(commented_block): + """Extract mermaid code from a commented block.""" + # Remove the comment indentation (3 spaces at start of each line) + lines = commented_block.split("\n") + content_lines = [] + + for line in lines: + if line.startswith(" "): + # Remove the 3-space comment indentation + content_lines.append(line[3:]) + elif line.strip() == "": + content_lines.append("") + + # Now we have the mermaid block, extract the actual mermaid code + mermaid_content = "\n".join(content_lines) + + # Remove the ".. mermaid::" directive and extract the indented content + mermaid_match = re.search( + r"\.\.
mermaid::\s*\n((?:(?:\n| .*))*)", mermaid_content + ) + if mermaid_match: + mermaid_code = mermaid_match.group(1) + # Remove RST indentation from mermaid code + code_lines = [] + for line in mermaid_code.split("\n"): + if line.startswith(" "): + code_lines.append(line[3:]) + elif line.strip() == "": + code_lines.append("") + return "\n".join(code_lines).strip() + + return None + + +def convert_mermaid_to_svg(mermaid_code, output_path): + """Convert mermaid code to SVG using mmdc.""" + # Create a temporary file for the mermaid code + with tempfile.NamedTemporaryFile( + mode="w", suffix=".mmd", delete=False, encoding="utf-8" + ) as tmp: + tmp.write(mermaid_code) + tmp_path = tmp.name + + try: + # Run mmdc to convert to SVG + subprocess.run( + [ + "mmdc", + "-i", + tmp_path, + "-o", + str(output_path), + "-t", + "neutral", + "-b", + "transparent", + ], + capture_output=True, + text=True, + check=True, + shell=True, # Required for Windows .cmd files + ) + return True, None + except subprocess.CalledProcessError as e: + return False, e.stderr + finally: + # Clean up temp file + os.unlink(tmp_path) + + +def process_file(file_path, force_update=False): + """Process a single RST file to update diagrams.""" + print(f"Checking {file_path.name}...") + + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + # Find all commented mermaid blocks + matches = list(COMMENTED_MERMAID_PATTERN.finditer(content)) + + if not matches: + print(" No commented mermaid diagrams found.") + return 0, 0 + + updated_count = 0 + error_count = 0 + + for match in matches: + commented_mermaid = match.group(1) + svg_filename = match.group(2) + svg_path = DIAGRAMS_DIR / svg_filename + + # Extract mermaid code + mermaid_code = extract_mermaid_from_comment(commented_mermaid) + if not mermaid_code: + print(f" ⚠ Could not extract mermaid code for {svg_filename}") + error_count += 1 + continue + + # Check if SVG needs updating + needs_update = force_update or not svg_path.exists() + + 
if not needs_update: + # For a more sophisticated check, we could hash the mermaid code + # and compare with a stored hash, but for simplicity we just check existence + print(f" ✓ {svg_filename} exists (use --force to regenerate)") + continue + + # Generate SVG + success, error = convert_mermaid_to_svg(mermaid_code, svg_path) + + if success: + print(f" ✓ Updated: {svg_filename}") + updated_count += 1 + else: + print(f" ✗ Error updating {svg_filename}: {error}") + error_count += 1 + + return updated_count, error_count + + +def find_rst_files(): + """Find all RST files in the CK tile docs directory.""" + return list(DOCS_DIR.glob("*.rst")) + + +def main(): + """Main function.""" + print("CK Tile Diagram Updater") + print("=" * 50) + + # Verify mmdc is available + try: + subprocess.run( + ["mmdc", "--version"], capture_output=True, check=True, shell=True + ) + except (subprocess.CalledProcessError, FileNotFoundError): + print("Error: mermaid-cli (mmdc) not found. Please install it:") + print(" npm install -g @mermaid-js/mermaid-cli") + return 1 + + # Ensure diagrams directory exists + DIAGRAMS_DIR.mkdir(parents=True, exist_ok=True) + + # Parse command line arguments + force_update = "--force" in sys.argv or "-f" in sys.argv + specific_file = None + + for arg in sys.argv[1:]: + if arg not in ["--force", "-f"] and arg.endswith(".rst"): + specific_file = DOCS_DIR / arg + if not specific_file.exists(): + print(f"Error: File not found: {arg}") + return 1 + + # Get files to process + if specific_file: + files_to_process = [specific_file] + else: + files_to_process = find_rst_files() + + # Process files + total_updated = 0 + total_errors = 0 + + for file_path in files_to_process: + updated, errors = process_file(file_path, force_update) + total_updated += updated + total_errors += errors + + print("\n" + "=" * 50) + print("✓ Update complete!") + print(f" Updated: {total_updated} diagram(s)") + if total_errors > 0: + print(f" Errors: {total_errors}") + + return 0 if 
total_errors == 0 else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/docs/index.rst b/docs/index.rst index c28eb646b5..865914ab4c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,6 +25,7 @@ The Composable Kernel repository is located at `https://github.com/ROCm/composab * :doc:`Composable Kernel structure <./conceptual/Composable-Kernel-structure>` * :doc:`Composable Kernel mathematical basis <./conceptual/Composable-Kernel-math>` + * :doc:`CK Tile conceptual documentation <./conceptual/ck_tile/index>` .. grid-item-card:: Tutorials diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 33ad8d91f8..c82e07ced8 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -18,6 +18,8 @@ subtrees: title: Composable Kernel structure - file: conceptual/Composable-Kernel-math.rst title: Composable Kernel mathematical basis + - file: conceptual/ck_tile/index.rst + title: CK Tile conceptual documentation - caption: Tutorial entries: From d184eed823ca50dcafc57c66228f12300c0c9ccc Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Thu, 4 Dec 2025 11:45:49 -0800 Subject: [PATCH 07/65] [CK-Tile] Refactor base pipeline usage (#3251) * initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke 
builder --- .../03_gemm/gemm_splitk_two_stage_invoker.hpp | 73 ++---- .../03_gemm/gemm_splitk_two_stage_reduce.cpp | 61 ++--- .../gemm_weight_preshuffle_invoker.hpp | 73 ++---- example/ck_tile/03_gemm/run_gemm_example.inc | 19 +- .../03_gemm/universal_gemm_invoker.hpp | 71 ++---- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 144 +++++------- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 166 ++++++-------- .../17_grouped_gemm/grouped_gemm_multi_d.cpp | 95 +++----- .../grouped_gemm_preshuffle.cpp | 170 ++++++-------- .../19_gemm_multi_d/gemm_multi_d_fp16.cpp | 137 +++++------ ...uped_convolution_backward_data_invoker.hpp | 65 +----- ...ed_convolution_backward_weight_invoker.hpp | 65 +----- ...tion_backward_weight_two_stage_invoker.hpp | 65 +----- .../grouped_convolution_forward_invoker.hpp | 198 ++++++---------- ...nvolution_forward_large_tensor_invoker.hpp | 74 +----- .../22_gemm_multi_abd/gemm_multi_abd_fp16.cpp | 141 +++++------- .../batched_contraction.cpp | 31 +-- .../test/test_bwd_data_instance_traits.cpp | 2 - .../test/test_bwd_weight_instance_traits.cpp | 2 - .../builder/test/test_fwd_instance_traits.cpp | 2 - .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 54 +++-- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 41 ++-- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 50 ++-- .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 46 ++-- .../gemm_pipeline_ag_bg_cr_comp_v6.hpp | 54 +++-- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 47 ++-- .../gemm/pipeline/gemm_pipeline_problem.hpp | 4 - .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 43 ++-- .../batched_gemm/test_batched_gemm_util.hpp | 103 +++------ test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 66 ++---- .../test_gemm_multi_abd_util.hpp | 76 ++---- .../gemm_multi_d/test_gemm_multi_d_util.hpp | 72 ++---- .../test_gemm_pipeline_util.hpp | 68 ++---- .../grouped_gemm/test_grouped_gemm_util.hpp | 112 ++++----- .../test_grouped_gemm_multi_d_util.hpp | 80 ++----- .../test_grouped_gemm_preshuffle_util.hpp | 217 +++++++----------- 
tile_engine/ops/gemm/gemm_instance_builder.py | 61 +---- 37 files changed, 1012 insertions(+), 1836 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp index 62744d9895..c312a53c2a 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -46,14 +46,6 @@ struct SplitKTwoStageInvoker GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using WorkspaceType = ck_tile::remove_cvref_t; - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< 
- GemmConfig::Pipeline>::template GemmPipeline; - - using WorkspaceType = ck_tile::remove_cvref_t; - using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( @@ -244,21 +217,15 @@ struct SplitKTwoStageInvoker ck_tile::make_tuple(args.N, 1), // Output Stride input_tensors, static_cast(c_ptr))); - - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } - }; - - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index 74edddb6c9..abad4ab5c4 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -133,14 +133,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - // Create base GEMM arguments 
pointing to workspace instead of final output // The workspace will store partial results from each K-split ck_tile::GemmHostArgs base_args(args.a_ptr, @@ -179,23 +158,18 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& args.stride_A, args.stride_B, args.stride_E); + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&]() { + // use SET operation since each K-split writes to separate memory + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + scheduler>; using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; @@ -276,29 +250,20 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - return ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - return ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - // For workspace mode, always use SET operation since each K-split writes to separate memory - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - }; - - return 
ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return Run(); } /** diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp index 07f449f34b..b394598110 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp @@ -33,14 +33,6 @@ struct WeightPreshuffleInvoker GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - 
using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem{}); - } - else - { - throw std::runtime_error("split-k is not supported yet!"); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("split-k is not supported yet!"); + } } }; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 30cb3d3476..c4f100b36b 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -63,14 +63,17 @@ void permute_tensor_b(Tensor& tensor) GemmConfig::TransposeC, GemmConfig::UseStructuredSparsity>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index b9b05a8e86..0fcf9680bc 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -34,14 +34,6 @@ struct UniversalInvoker GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + constexpr auto scheduler = GemmConfig::Scheduler; - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = 
BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } - }; - - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 6838e899e6..c7e37bc8a7 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -59,7 +59,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = 
ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - float ave_time{0}; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, 
args.batch_count); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } #include "run_batched_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 531e437006..3ff3f2f10e 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -42,12 +42,6 @@ float grouped_gemm(const std::vector& gemm_descs, GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits& gemm_descs, BLayout, CLayout, GemmConfig::TransposeC>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + constexpr auto scheduler = GemmConfig::Scheduler; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = 
BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - float ave_time{0}; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; - - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) + using 
GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); }; - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } template & gemm_d GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits& gemm_d BLayout, ELayout, GemmConfig::TransposeC>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; - const 
ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem& gemm_d << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - - return ave_time; + return ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + if(gemm_descs[0].k_batch == 1) + { + return 
Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } template & gemm_descs, GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits& gemm_descs, GemmConfig::Persistent, GemmConfig::NumWaveGroups, GemmConfig::Preshuffle>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = - // if preshuffle == true then num_loop is recalculated for each group in the kernel code - TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; - float ave_time{0}; - - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; - - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - 
if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << 
std::endl; } + + return ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); }; - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } template ; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + constexpr auto scheduler = GemmConfig::Scheduler; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - float ave_time{0}; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using Kernel = ck_tile::GemmKernelMultiD; + auto kargs = 
Kernel::MakeKernelArgs(args); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::GemmKernelMultiD; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " - << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " - << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y + << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } #include "run_gemm_multi_d_fp16_example.inc" diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index 7638b92002..d2663b033c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -57,43 +57,9 @@ struct GroupedConvolutionBackwardDataInvoker GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; + constexpr auto scheduler = ConvConfig::Scheduler; - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - OutDataType, - WeiDataType, - AccDataType, - GemmShape, - typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData< - ConvConfig::NumWaveGroups>, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - InDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; - - using BaseGemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t gemm_k = - args.K_ * 
std::accumulate(args.filter_spatial_lengths_.begin(), - args.filter_spatial_lengths_.end(), - 1, - std::multiplies()); - - const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< @@ -103,8 +69,6 @@ struct GroupedConvolutionBackwardDataInvoker GemmShape, GemmUniversalTraits, scheduler, - has_hot_loop_v, - tail_number_v, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, InDataType, @@ -170,26 +134,19 @@ struct GroupedConvolutionBackwardDataInvoker kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git 
a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index f7171ef9d9..0891e8c20b 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -57,43 +57,9 @@ struct GroupedConvolutionBackwardWeightInvoker GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; + constexpr auto scheduler = ConvConfig::Scheduler; - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< - ConvConfig::NumWaveGroups>, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; - - using BaseGemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t gemm_k = - args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), - args.output_spatial_lengths_.end(), - 1, - std::multiplies()); - - const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr 
auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< @@ -103,8 +69,6 @@ struct GroupedConvolutionBackwardWeightInvoker GemmShape, GemmUniversalTraits, scheduler, - has_hot_loop_v, - tail_number_v, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, @@ -176,26 +140,19 @@ struct GroupedConvolutionBackwardWeightInvoker } }; - ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index 5d78bc4739..50c0ce4f87 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -60,42 +60,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - typename GroupedConvTraitsType::template 
GroupedConvImplicitGemmTraitsBwdWeight< - ConvConfig::NumWaveGroups>, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + constexpr auto scheduler = ConvConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t gemm_k = - args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), - args.output_spatial_lengths_.end(), - 1, - std::multiplies()); - - const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< @@ -105,8 +72,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker GemmShape, GemmUniversalTraits, scheduler, - has_hot_loop_v, - tail_number_v, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, @@ -209,7 +174,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker { std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' << 
"pipeline: " << GemmPipeline::GetName() << '\n' << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z @@ -228,7 +192,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), @@ -242,22 +206,15 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ck_tile::make_tuple(shape[1], 1), // Output Stride input_tensors, static_cast(c_ptr))); - - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 3e1f4c6268..82541bb593 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -65,148 +65,96 @@ struct GroupedConvolutionForwardInvoker GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; - - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - GemmShape, - typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd< - ConvConfig::NumWaveGroups>, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - OutDataType, - 
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; - - using BaseGemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t gemm_k = - args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), - args.filter_spatial_lengths_.end(), - 1, - std::multiplies()); - - // Split-K parameters - const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; + constexpr auto scheduler = ConvConfig::Scheduler; // ===================================================================== // Regular Convolution: Simple, no split-image // ===================================================================== - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - has_hot_loop_v, - tail_number_v, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - OutDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using 
UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionForwardKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - - // ===================================================================== - // Split-K lambda - // ===================================================================== - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpSet{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); } - else + + if(s.log_level_ > 0) { - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; } + + return ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; // ===================================================================== - // Regular Convolution Example: ALWAYS uses regular path (Kernel) + // Split-K dispatch // ===================================================================== - // This example demonstrates regular convolution without split-image. 
- // For large images that don't fit in memory, use - // grouped_convolution_forward_split_image.cpp - - // Launch kernel using regular path (no split-image) - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index d154d8710b..4261385a84 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -72,36 +72,6 @@ struct GroupedConvolutionForwardInvoker GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - GemmShape, - typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd< - ConvConfig::NumWaveGroups>, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - OutDataType, - GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsTypeDefault::VectorSizeA, - GroupedConvTraitsTypeDefault::VectorSizeB>; - - using BaseGemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t gemm_k = - args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), - args.filter_spatial_lengths_.end(), - 1, - std::multiplies()); - - // Split-K parameters - const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = 
BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - using TransformType = ck_tile::TransformConvFwdToGemm{}); - else - return Run(has_hot_loop_, - tail_number_, - MemoryOpAtomicAdd{}, - ck_tile::bool_constant{}); - }; - return BaseGemmPipeline::TailHandler(RunSplitImage, has_hot_loop, tail_num); + if(args.k_batch == 1) + return Run(MemoryOpSet{}, ck_tile::bool_constant{}); + else + return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); } else { - const auto RunRegular = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - return Run(has_hot_loop_, - tail_number_, - MemoryOpSet{}, - ck_tile::bool_constant{}); - else - return Run(has_hot_loop_, - tail_number_, - MemoryOpAtomicAdd{}, - ck_tile::bool_constant{}); - }; - return BaseGemmPipeline::TailHandler(RunRegular, has_hot_loop, tail_num); + if(args.k_batch == 1) + return Run(MemoryOpSet{}, ck_tile::bool_constant{}); + else + return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); } } }; diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp index 5ea4299492..acb9126d65 100644 --- a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp @@ -63,8 +63,6 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - 
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; - float ave_time{0}; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::GemmKernelMultiABD; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " - << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " - << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y + << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } #include "run_gemm_multi_abd_fp16_example.inc" diff --git a/example/ck_tile/41_batched_contraction/batched_contraction.cpp b/example/ck_tile/41_batched_contraction/batched_contraction.cpp index 6536894394..f9f13c6e85 100644 --- a/example/ck_tile/41_batched_contraction/batched_contraction.cpp +++ b/example/ck_tile/41_batched_contraction/batched_contraction.cpp @@ -90,24 +90,9 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs; - using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; - ck_tile::index_t K_total = 1; - for(ck_tile::index_t i = NumDimG + NumDimM; i < NumDimG + NumDimM 
+ NumDimK; ++i) - { - K_total *= args.A_dims[i]; - } - - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_total); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = [&]() { constexpr auto memory_operation = ck_tile::memory_operation_enum::set; // Always set (no atomic_add) @@ -116,9 +101,7 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs; + scheduler>; using GemmPipeline = GEMM_PIPELINE; @@ -166,14 +149,10 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs(Kernel{}, grids, blocks, 0, kargs); - ave_time = ck_tile::launch_kernel(s, kernel); - - return ave_time; + return ck_tile::launch_kernel(s, kernel); }; - BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); - - return ave_time; + return Run(); } #define HANDLE_CASE(G, M, N, K) \ diff --git a/experimental/builder/test/test_bwd_data_instance_traits.cpp b/experimental/builder/test/test_bwd_data_instance_traits.cpp index 6b18095544..80e8ae8d98 100644 --- a/experimental/builder/test/test_bwd_data_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_data_instance_traits.cpp @@ -54,8 +54,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) GemmShape, GemmUniversalTraits, ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/, - true /*has_hot_loop_v*/, - ck_tile::TailNumber::Full /*tail_number_v*/, ck_tile::element_wise::PassThrough /*AElementwiseOperation*/, ck_tile::element_wise::PassThrough /*BElementwiseOperation*/, ck_tile::bf16_t /*InDataType*/, diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp 
b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index 3ecd06e33d..9b3cd169bb 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -156,8 +156,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) GemmShape, GemmUniversalTraits, ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/, - true /*has_hot_loop_v*/, - ck_tile::TailNumber::Full /*tail_number_v*/, ck_tile::element_wise::PassThrough /*AElementwiseOperation*/, ck_tile::element_wise::PassThrough /*BElementwiseOperation*/, ck_tile::bf16_t /*WeiDataType*/, diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 9da707bfec..6a8f1f14e3 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -767,8 +767,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) GemmShape, GemmUniversalTraits, ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/, - true /*has_hot_loop_v*/, - ck_tile::TailNumber::Full /*tail_number_v*/, ck_tile::element_wise::PassThrough /*AElementwiseOperation*/, ck_tile::element_wise::PassThrough /*BElementwiseOperation*/, ck_tile::bf16_t /*OutDataType*/, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index d27f937435..0b2cdde05e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -19,12 +19,12 @@ struct BaseGemmPipelineAgBgCrCompAsync static constexpr index_t PrefillStages = 1; static constexpr index_t GlobalBufferNum = 1; - CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) { return num_loop > 
PrefetchStages; } - CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { if(num_loop == 1) { @@ -158,9 +158,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}; static constexpr auto is_b_load_tr_v = bool_constant{}; @@ -539,14 +537,21 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem_0, - p_smem_1); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0, + p_smem_1); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } public: @@ -557,14 +562,21 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}.template operator()( - a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, - b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, - num_loop, - p_smem_0, - p_smem_1); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem_0, + p_smem_1); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index f83462391c..d4475e8c60 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -154,10 +154,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr index_t Preshuffle = Problem::Preshuffle; - static constexpr bool HasHotLoop = - Problem::HasHotLoop; // Base::BlockHasHotloop(Problem::num_loop); - static constexpr auto TailNum = - Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop); static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto is_a_load_tr_v = bool_constant{}; @@ -641,13 +637,20 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 index_t num_loop, void* p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } /** @@ -700,13 +703,15 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 index_t num_loop, void* p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - b_dram_block_window_tmp, - [](auto& e, const BDataType& b) { e = b; }, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + return 
operator()(a_dram_block_window_tmp, + b_dram_block_window_tmp, + num_loop, + has_hot_loop, + tail_number, + p_smem); } template static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr index_t Preshuffle = Problem::Preshuffle; - static constexpr bool HasHotLoop = Problem::HasHotLoop; - static constexpr auto TailNum = Problem::TailNum; - static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; @@ -685,14 +683,21 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 void* p_smem_0, void* p_smem_1) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem_0, - p_smem_1); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0, + p_smem_1); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template void* __restrict__ p_smem_0, void* __restrict__ p_smem_1) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - b_dram_block_window_tmp, - [](auto& e, const BDataType& b) { e = b; }, - num_loop, - p_smem_0, - p_smem_1); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const 
BDataType& b) { e = b; }, + num_loop, + p_smem_0, + p_smem_1); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; static constexpr index_t Preshuffle = Problem::Preshuffle; - static constexpr bool HasHotLoop = Problem::HasHotLoop; - static constexpr auto TailNum = Problem::TailNum; - static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto Scheduler = Problem::Scheduler; static constexpr index_t NumWarps = BlockGemmShape::NumWarps; static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{}); @@ -404,13 +402,20 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 index_t num_loop, void* p_smem_0) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem_0); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template const index_t num_loop, void* __restrict__ p_smem_0) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - b_dram_block_window_tmp, - [](auto& e, const BDataType& b) { e = b; }, - num_loop, - p_smem_0); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const 
BDataType& b) { e = b; }, + num_loop, + p_smem_0); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template PrefetchStages; } - CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { if(num_loop % HotloopUnroll == 1) { @@ -153,9 +153,7 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; static constexpr index_t Preshuffle = Problem::Preshuffle; - static constexpr bool HasHotLoop = Problem::HasHotLoop; - static constexpr auto TailNum = Problem::TailNum; - static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; @@ -173,11 +171,9 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 return concat('_', "pipeline_AgBgCrCompV6", BlockSize, concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), concat('x', kPadM, kPadN, kPadK), - concat('x', TailNum), concat('_', KRepeat), concat('_', DoubleSmemBuffer), - concat('_', Preshuffle), - concat('_', HasHotLoop)); + concat('_', Preshuffle)); // clang-format on } @@ -725,13 +721,20 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 index_t num_loop, void* __restrict__ p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + }; 
+ + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template const index_t num_loop, void* __restrict__ p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - b_dram_block_window_tmp, - [](auto& e, const BDataType& b) { e = b; }, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const BDataType& b) { e = b; }, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr index_t Preshuffle = Problem::Preshuffle; - // Where is the right place for HasHotLoop and TailNum ??? 
- static constexpr bool HasHotLoop = Problem::HasHotLoop; - static constexpr auto TailNum = Problem::TailNum; - static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; @@ -887,13 +884,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem index_t num_loop, void* p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template index_t num_loop, void* p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - b_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const BDataType& b) { e = b; }, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template = DsReadPreload) ? 
DsReadPreload : MIterPerWarp * KIterPerWarp; - static constexpr auto TailNum = Problem::TailNum; #ifdef __gfx942__ static constexpr index_t mfma_per_wg = 2; @@ -1042,13 +1041,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 void* p_smem_ping, void* p_smem_pong) const { - return operator()( - a_dram_block_window_tmp[number<0>{}], - [](const ADataType& a) { return a; }, - b_flat_dram_block_window_tmp[number<0>{}], - num_loop, - p_smem_ping, - p_smem_pong); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto bool_val, auto tail_num_) { + (void)bool_val; // Suppress unused parameter warning + constexpr auto tail_num = tail_num_.value; + constexpr auto PassThrough = [](const ADataType& a) { return a; }; + return operator()(a_dram_block_window_tmp[number<0>{}], + PassThrough, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem_ping, + p_smem_pong); + }; + return Base::TailHandler(RunPipeline, true, tail_number); } // called from general gemm kernel @@ -1063,13 +1069,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 void* p_smem_ping, void* p_smem_pong) const { - return operator()( - a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, - b_flat_dram_block_window_tmp, - num_loop, - p_smem_ping, - p_smem_pong); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto bool_val, auto tail_num_) { + (void)bool_val; // Suppress unused parameter warning + constexpr auto tail_num = tail_num_.value; + constexpr auto PassThrough = [](const ADataType& a) { return a; }; + return operator()(a_dram_block_window_tmp, + PassThrough, + b_flat_dram_block_window_tmp, + num_loop, + p_smem_ping, + p_smem_pong); + }; + return Base::TailHandler(RunPipeline, true, tail_number); } // called from grouped gemm kernel diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 3c344259bb..77eb416532 100644 
--- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -81,7 +81,6 @@ class TestCkTileBatchedGemm : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -154,36 +135,26 @@ class TestCkTileBatchedGemm : public ::testing::Test { std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' << "shape: " << 
GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' << "pipeline: " << GemmPipeline::GetName() << '\n' << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } } public: diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index c489d3be54..a0c078a1e9 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -159,8 +159,6 @@ class TestCkTileGemmPipeline : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using BaseGemmPipeline = - typename GemmPipelineTypeSelector::base_pipeline; + using GemmPipeline = + typename GemmPipelineTypeSelector::pipeline; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const 
ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = - typename GemmPipelineTypeSelector::pipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } } public: diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp index 8234692696..ee045c7f48 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -134,7 +134,6 @@ class TestCkTileGemmMultiABD : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + 
k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - float ave_time{0}; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< ck_tile::DefaultGemm2DEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - std::cout << "Run without SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - std::cout << "Run using SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + std::cout << "Run without SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } + else + { + std::cout << "Run using SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } } public: diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index 373370b18c..8217f5a3d9 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp 
+++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -150,7 +150,6 @@ class TestCkTileGemmMultiD : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - float ave_time{0}; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< ck_tile::DefaultGemm2DEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - std::cout << "Run without SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - std::cout << "Run using SplitK" << 
std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + std::cout << "Run without SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } + else + { + std::cout << "Run using SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } } public: diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 928c72b62d..43a73738d9 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -132,8 +132,6 @@ class TestCkTileGemmPipeline : public ::testing::Test GemmConfig::K_Warp_Tile>>; using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - - using Traits = ck_tile::TileGemmTraits; static constexpr bool StructuredSparsity = false; static constexpr bool NumWaveGroup = 1; @@ -150,37 +148,19 @@ class TestCkTileGemmPipeline : public ::testing::Test NumWaveGroup, preshuffle>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using BaseGemmPipeline = - typename GemmPipelineTypeSelector::base_pipeline; + using GemmPipeline = + typename GemmPipelineTypeSelector::pipeline; - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = 
tail_number_.value; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = - typename GemmPipelineTypeSelector::pipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } } public: diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index a64542aa95..db51a3e8b2 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -91,12 +91,6 @@ class TestCkTileGroupedGemm : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile; - const ck_tile::index_t K_split = - (gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = 
BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GroupedGemmKernel; auto kargs = Kernel::MakeKargs(gemm_descs); EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); @@ -176,7 +151,7 @@ class TestCkTileGroupedGemm : public ::testing::Test << blocks.z << "}" << std::endl; } - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel( Kernel{}, @@ -185,29 +160,20 @@ class TestCkTileGroupedGemm : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), gemm_descs.size())); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - std::cout << "Run without SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - std::cout << "Run using SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + std::cout << "Run without SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } + else + { + 
std::cout << "Run using SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } } template diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index 4397668a5d..b065df6f8a 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -104,8 +104,6 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; - // for testing purposes, we can hardcode the values here as we what is compatible with // pipeline using GemmUniversalTraits = @@ -121,49 +119,24 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test /*Persistent*/ false, /*NumWaveGroups*/ 1, /*Preshuffle*/ false>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = std::conditional_t< + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = std::conditional_t< Config::Pipeline_ == (PipelineType::Memory), - ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::GemmPipelineAgBgCrMem, std::conditional_t, - ck_tile::BaseGemmPipelineAgBgCrCompV4>>; + ck_tile::GemmPipelineAgBgCrCompV3, + ck_tile::GemmPipelineAgBgCrCompV4>>; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile_; - const ck_tile::index_t K_split = - (gemm_descs[0].K + k_grain - 1) / k_grain * Config::K_Tile_; - const ck_tile::index_t num_loop = - ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = 
has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = std::conditional_t< - Config::Pipeline_ == (PipelineType::Memory), - ck_tile::GemmPipelineAgBgCrMem, - std::conditional_t, - ck_tile::GemmPipelineAgBgCrCompV4>>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( Kernel{}, @@ -211,25 +184,18 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), gemm_descs.size())); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } } void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index c322aac575..0eb388082b 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -123,8 +123,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; - // for testing purposes, we can hardcode the values here as we what is compatible with // pipeline using 
GemmUniversalTraits = @@ -140,58 +138,37 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test /*Persistent*/ false, /*NumWaveGroups*/ 1, /*Preshuffle*/ true>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = + ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = - ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = - ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GroupedGemmKernel; auto kargs = Kernel::MakeKargs(gemm_descs); EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); @@ -204,7 +181,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test hipMemcpyHostToDevice, s.stream_id_)); - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel( Kernel{}, @@ -213,25 +190,18 @@ class TestCkTileGroupedGemmPreshuffle : public 
::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), gemm_descs.size())); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } } private: @@ -247,8 +217,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; - // Enable persistent mode for preshuffle using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; - - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = - ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = + ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation 
= memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = - ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GroupedGemmKernel; auto kargs = Kernel::MakeKargs(gemm_descs); EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); @@ -327,7 +273,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test hipMemcpyHostToDevice, s.stream_id_)); - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel( Kernel{}, @@ -336,25 +282,18 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), gemm_descs.size())); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } } public: diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index d450f20105..65fede6a5f 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -337,13 +337,6 @@ class GemmKernelBuilder: "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", } - # Map pipeline names to base pipeline for hot loop detection - base_pipeline_map = { - "mem": "ck_tile::BaseGemmPipelineAgBgCrMem", - "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3", - "compv4": 
"ck_tile::BaseGemmPipelineAgBgCrCompV4", - } - # Map scheduler names to the correct enum values scheduler_type_map = { "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", @@ -423,33 +416,10 @@ struct SelectedKernel {{ // Tile partitioner using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; - - // Traits - using Traits = ck_tile::TileGemmTraits; - - // Pipeline problem - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - ADataType, - BDataType, - AccDataType, - TileShape, - Traits>; - - // Base pipeline for hot loop detection - using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}; static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ - const ck_tile::index_t k_grain = args.k_batch * TileK; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{{0}}; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto memory_operation_) {{ constexpr auto scheduler = {scheduler_type_map.get(scheduler)}; [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; @@ -462,9 +432,7 @@ struct SelectedKernel {{ ALayout, BLayout, CLayout, TransposeC, UseStructuredSparsity, UsePersistentKernel, NumWaveGroups, Preshuffle>, - scheduler, - has_hot_loop_v, - tail_number_v>; + scheduler>; using GemmPipeline = {pipeline_impl_map.get(pipeline)}; @@ -542,28 +510,23 @@ struct SelectedKernel {{ // Launch kernel constexpr int kBlockPerCu = {k_block_per_cu}; - ave_time = ck_tile::launch_kernel( + float ave_time = ck_tile::launch_kernel( stream, 
ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); return ave_time; }}; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{ - if(args.k_batch == 1) {{ - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{{}}); - }} else {{ - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{{}}); - }} - }}; + float ave_time = 0.f; + + if(args.k_batch == 1) {{ + ave_time = Run(ck_tile::integral_constant{{}}); + }} else {{ + ave_time = Run(ck_tile::integral_constant{{}}); + }} - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); return ave_time; }} }}; From d1193e8637a4ac82217d0413e67ed52700c7f8fc Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 4 Dec 2025 18:29:14 -0800 Subject: [PATCH 08/65] fix hipblaslt build for different archs (#3358) --- Dockerfile.pytorch | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.pytorch b/Dockerfile.pytorch index 4533166c06..9628bf46fa 100644 --- a/Dockerfile.pytorch +++ b/Dockerfile.pytorch @@ -29,4 +29,4 @@ RUN groupadd -g 109 render && \ git sparse-checkout set projects/hipblaslt shared/origami && \ cd projects/hipblaslt && \ git show --oneline -s && \ - CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --logic-yaml-filter gfx950/*/* --architecture="gfx942;gfx950" -j 128 --skip_rocroller + CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx90a;gfx942;gfx950" -j 128 --skip_rocroller From 05292b3604e143e98ec2cb67edb2e3d2ad1d6ecb Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 5 Dec 2025 10:31:12 +0800 Subject: [PATCH 09/65] [CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153) * Let fmha_fwd_v3() compatible with fmha_fwd() * Decouple get_fwd_blobs() and FmhaFwdKernel * Decouple compatibility checks from get_fwd_blobs() * Extract product feature checks out from get_fwd_blobs() * Remove duplicated code in factories and 
redundant checks * Remove FmhaFwdKernel<>::GetName() * Let FmhaFwdApiPool support pipelines with different mask_impl * Add tile setting for fmha fwd v3 pipeline * Add fwd v3 instances to tile_example_fmha_fwd manually * Remove unused function import * Undo irrelevant changes * Remove fwd v3 instances from tile_example_fmha_fwd * Finish fmha fwd v3 kernel instance codegen * Fix formatting * Remove unused F_idx attribute * Add is_generic_attention_mask<> traits * Add constraints to the fmha fwd v3 pipeline * Unify traits & problem used for fmha fwd v3 * Unify kernel launch code for fmha fwd v2 & v3 * Unify kernel template selection logic * Use same kernel codegen template for both v2 & v3 * Rename api() property as render() method * Allow specifying filter for fmha fwd api pool * Allow specifying function name when rendering api pool items * Separate fmha fwd v3 kernel dispatching logic from v2 * Remove lambda assignment * Add simple v2/v3 dispatch logic * Stop generating empty if-clauses Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them. 
* Use "".join() to concatenate fmha fwd api string content * Add more feature checks for fmha fwd v3 pipeline * Check features before dispatch to fmha_fwd_v3() * Add more feature checks for fmha_fwd_v3() * Add missing filter call * Use Tuple to reserve the dtype orders * Fix wrong pipeline matching logic * Add fmha fwd v3 group mode instances * Add functor_transform<> * Add type constraints to make_tile_window() * Remove fmha fwd v3 example * Fix wrong product(aiter mha_fwd()) config * Fix wrong fmha fwd v2/v3 selection logic * Fix formatting * Add comment to warning v3 kernel users * Fix wrong codegen logics * Remove unnecessary param * Fix format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- example/ck_tile/01_fmha/CMakeLists.txt | 34 - .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 20 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 827 ++++++++++++------ .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 616 ------------- example/ck_tile/01_fmha/fmha_fwd.hpp | 94 ++ example/ck_tile/01_fmha/fmha_fwd_v3.cpp | 60 -- example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 73 -- example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 179 ---- .../instances/fmha_fwd_v3_d128_bf16_mask.cpp | 14 - .../instances/fmha_fwd_v3_d128_bf16_nmask.cpp | 14 - .../instances/fmha_fwd_v3_d128_fp16_mask.cpp | 14 - .../instances/fmha_fwd_v3_d128_fp16_nmask.cpp | 14 - .../core/algorithm/coordinate_transform.hpp | 82 ++ include/ck_tile/core/tensor/tile_window.hpp | 9 +- .../ck_tile/ops/fmha/block/block_masking.hpp | 13 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 48 - .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 134 +-- .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 32 +- .../pipeline/block_fmha_pipeline_enum.hpp | 1 + .../pipeline/block_fmha_pipeline_problem.hpp | 43 - .../ops/fmha/pipeline/tile_fmha_traits.hpp | 16 - include/ck_tile/remod.py | 2 +- 22 files changed, 890 insertions(+), 1449 deletions(-) delete mode 100644 example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp 
delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.cpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.hpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 0c8102a70b..6e7d69281d 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -208,40 +208,6 @@ add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -# add fmha_fwd_v3 example -set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3") -message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") - -add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) -target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS - "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" -) -target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE - fmha_fwd_v3.cpp - ${FMHA_FWD_V3_INSTANCES} -) - -set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS) -list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -fgpu-flush-denormals-to-zero - -Wno-undefined-func-template - --save-temps -) -set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) - -check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) -if(HAS_DISABLE_PACKED_FP32) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -mllvm --amdgpu-disable-packed-fp32=1 - ) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS - 
-DCK_TILE_DISABLE_PACKED_FP32=1 - ) -endif() - -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) -target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 333579ec8d..a3cfe2622a 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -30,16 +30,24 @@ _MASK_MAP = { } -def get_mask_map(mask: str): - if mask == "generic": +def get_mask_map(mask_impl: str): + if mask_impl == "generic": return _MASK_MAP - elif mask == "simplified": + elif mask_impl == "simplified": return _MASK_SIMPLIFIED_MAP else: assert False return None +def get_mask_impl(mask: str) -> str: + return "simplified" if mask.startswith("s_") else "generic" + + +def get_mask_cpp_type(mask: str) -> str: + return get_mask_map(get_mask_impl(mask))[mask] + + _MASK_CHECK_MAP = { "no": "t.mask_type == mask_enum::no_mask", "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", @@ -62,6 +70,10 @@ def get_mask_check_map(mask: str): return None +def get_mask_cpp_check_expr(mask: str) -> str: + return get_mask_check_map(get_mask_impl(mask))[mask] + + QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", @@ -122,6 +134,7 @@ PIPELINE_MAP = { "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs": "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", + "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline", } PIPELINE_ENUM_MAP = { @@ -131,6 +144,7 @@ PIPELINE_ENUM_MAP = { 
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", + "qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 17d4f6e1d7..c00bdcea3b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -8,14 +8,13 @@ import os from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional, Tuple +from typing import Callable, ClassVar, Iterable, List, Optional, Tuple from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR from codegen.cpp_symbol_map import ( LAYOUT_MAP, BIAS_CHECK_MAP, - get_mask_check_map, BOOL_MAP, PIPELINE_MAP, PIPELINE_ENUM_MAP, @@ -23,6 +22,8 @@ from codegen.cpp_symbol_map import ( FWD_DTYPE_MAP, BIAS_MAP, get_mask_map, + get_mask_cpp_type, + get_mask_cpp_check_expr, QSCALE_CHECK_MAP, QSCALE_MAP, ) @@ -48,79 +49,79 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT #include "fmha_fwd.hpp" """ -FMHA_FWD_KERNEL_BODY = """ +FMHA_FWD_KERNEL_BODY_TEMPLATE = """ #include #if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) -using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_dtype = {F_dtype}; -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; +using fmha_shape = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, 
{F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_logits}, - {F_bias}, - false, - {F_lse}, - {F_dropout}, - {F_qscale}, - {F_occupancy}, - {F_skip}>; +using fmha_traits = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_qscale}, + {F_occupancy}, + {F_skip}>; -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; -using fmha_mask_{F_idx} = {F_mask}; +using fmha_mask = {F_mask}; -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, +using fmha_pipeline_problem = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape, {F_mode}, - 
fmha_variant_{F_idx}, - fmha_mask_{F_idx}, + fmha_variant, + fmha_mask, {F_trload}, - fmha_trait_{F_idx}>; + fmha_traits>; -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {F_spad}, {F_dvpad}>>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel; +using fmha_kernel = {F_kernel}; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + +using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ - using k_ = fmha_kernel_{F_idx}; + using k_ = fmha_kernel; if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + std::cout << ", {F_kname}" << std::flush; + auto [kargs, grids] = {F_kargs_creator}(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); @@ -130,40 +131,47 @@ float fmha_fwd_(const ck_tile::stream_config& s, fm """ FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" -FMHA_FWD_API = """ 
+FMHA_FWD_API_HEADER = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py #include #include -namespace {{ -bool get_num_cus(unsigned& num_cus) {{ +#include "fmha_fwd.hpp" + +namespace { +bool get_num_cus(unsigned& num_cus) { int device; auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ + if(status != hipSuccess) { fprintf(stderr, "failed to get device"); return false; - }} + } - hipDeviceProp_t props{{}}; + hipDeviceProp_t props{}; status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ + if(status != hipSuccess) { fprintf(stderr, "failed to get device properties"); return false; - }} + } num_cus = props.multiProcessorCount; return true; -}} +} -unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) { const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s) {{ +} +} // namespace +""" +FMHA_FWD_API_FUNC_TEMPLATE = """ +namespace {{ +float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{ float r = -1; [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate @@ -182,6 +190,28 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& {F_dispatch} return r; }} +}} // namespace +""" +FMHA_FWD_API_FOOTER_TEMPLATE = """ +float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ + const std::string device_name = ck_tile::get_device_name(); + + const bool is_swa = 
(traits.mask_type != mask_enum::no_mask) and + ((0 < args.window_size_left) or (0 < args.window_size_right)); + const bool can_dispatch_v3 = + (device_name.compare(0, 6, "gfx950") == 0) and + (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and + traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and + (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and + (not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and + (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and + (args.hdim_v == 128); + if ({F_is_v3_enabled} and can_dispatch_v3) {{ + return fmha_fwd_v3(traits, args, config); + }} else {{ + return fmha_fwd_v2(traits, args, config); + }} +}} """ FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ @@ -261,7 +291,7 @@ class FmhaFwdApiTrait: def scheck(self) -> str: if self.mode == "group": return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.pipeline_tag in ["qr_async", "qr_async_trload"]: + if self.pipeline_tag in ["qr_async", "qr_async_trload", "qr_async_trload_v3"]: if self.spad == "t": return "true" # always support else: @@ -294,7 +324,7 @@ class FmhaFwdApiTrait: return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) else: return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" - elif self.pipeline_tag == "qr_async_trload": + elif self.pipeline_tag in ["qr_async_trload", "qr_async_trload_v3"]: if self.skpad == "t": return "true" else: @@ -310,7 +340,7 @@ class FmhaFwdApiTrait: return f"a.hdim_q % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == "t": return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) @@ -327,7 +357,7 @@ class FmhaFwdApiTrait: return f"a.hdim_v % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == "t": return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) @@ -429,9 +459,8 @@ class FmhaFwdPipeline: class FmhaFwdApiPool: - def __init__(self, mask_impl): + def __init__(self): self.pool = OrderedDict() - self.mask_impl = mask_impl def register_traits(self, trait: FmhaFwdApiTrait) -> None: hdim = trait.hdim, trait.bn1 @@ -443,19 +472,60 @@ class FmhaFwdApiPool: check_duplicates_and_paddings(ts, trait) ts.append(copy.copy(trait)) - @property - def api(self) -> str: + def get_num_traits( + self, filter_fn: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + ) -> int: + if filter_fn is None: + + def accept_all(trait: FmhaFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + return sum( + sum(1 for trait in pool_by_hdim if filter_fn(trait)) + for pool_by_arch in self.pool.values() + for pool_by_dtype in pool_by_arch.values() + for pool_by_hdim in pool_by_dtype.values() + ) + + def render( + self, func_name, filter_fn: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + ) -> str: + if filter_fn is None: + + def accept_all(trait: FmhaFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + def has_traits(node) -> bool: + """Recursively traverse nested OrderedDicts and lists to determine if any FmhaFwdApiTrait satisfies filter_fn().""" + if isinstance(node, list): + return any(filter_fn(elem) for elem in node) + elif isinstance(node, OrderedDict): + return any(has_traits(val) for val in node.values()) + return False + per_arch = str() - for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): + for i_arch, (arch, pool_by_arch) in enumerate( + item for item in self.pool.items() if has_traits(item[1]) + ): per_dtypes = str() - for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): + for i_dtype, (dtype, pool_by_dtype) in enumerate( + item for item in pool_by_arch.items() if has_traits(item[1]) + ): per_hdim_case = str() for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate( - pool_by_dtype.items() + item for item in pool_by_dtype.items() if has_traits(item[1]) ): - 
max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) + max_bm0 = max( + (t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0 + ) inners = str() - for i_trait, trait in enumerate(pool_by_hdim): + for i_trait, trait in enumerate( + [trait for trait in pool_by_hdim if filter_fn(trait)] + ): inners += FMHA_FWD_API_INNER_DISPATCH.format( F_if=if_(i_trait), F_arch=arch, @@ -463,8 +533,8 @@ class FmhaFwdApiPool: F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_mask=get_mask_cpp_type(trait.mask), + F_mask_check=get_mask_cpp_check_expr(trait.mask), F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], @@ -506,10 +576,9 @@ class FmhaFwdApiPool: F_arch=arch, F_dtype_case=indent(per_dtypes), ) - if not per_arch: - # empty string we add some ignore to suppress warning in api - per_arch = "(void)t; (void)s; (void)a;" - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch)) + return FMHA_FWD_API_FUNC_TEMPLATE.format( + F_func_name=func_name, F_dispatch=indent(per_arch) + ) @dataclass @@ -548,18 +617,32 @@ class FmhaFwdTileSize: @dataclass class FmhaFwdKernel: F_arch: ArchTrait - F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type F_mode: str # value from MODE_MAP F_tile: FmhaFwdTileSize F_pipeline: FmhaFwdPipeline - mask_impl: str - @property - def template(self) -> str: - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( - F_idx=self.F_idx, + _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER + _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE + + @classmethod + def _get_cpp_kernel_class_name(cls, pipeline_tag): + if pipeline_tag == "qr_async_trload_v3": + return "ck_tile::FmhaFwdV3Kernel" + else: + return 
"ck_tile::FmhaFwdKernel" + + @classmethod + def _get_cpp_kargs_creator_func_name(cls, pipeline_tag): + if pipeline_tag == "qr_async_trload_v3": + return "fmha_fwd_v3_create_kargs_and_grids" + else: + return "fmha_fwd_create_kargs_and_grids" + + def render(self) -> str: + return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format( + F_kname=self.name, F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], @@ -594,10 +677,12 @@ class FmhaFwdKernel: F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mask=get_mask_cpp_type(self.F_pipeline.F_mask), F_mode=MODE_MAP[self.F_mode], - F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag), + F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag), ) @property @@ -644,16 +729,179 @@ class FmhaFwdKernel: ) -class KernelComponentFactoryGfx9: +@dataclass +class ProblemContext: + dtype: str + mode: str + hdim: int + hdim_v: int + + +@dataclass +class KernelContext: + tile: FmhaFwdTileSize + pipeline: FmhaFwdPipeline + mask_impl: str + + +CompatibilityRule = Callable[[ProblemContext, KernelContext], bool] + + +def is_compatible( + problem_ctx: ProblemContext, + kernel_ctx: KernelContext, + rules: Iterable[CompatibilityRule], +) -> bool: + return all(rule(problem_ctx, kernel_ctx) for rule in rules) + + +def create_kernel( + arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext +) -> FmhaFwdKernel: + return FmhaFwdKernel( + F_arch=arch, + F_dtype=problem_ctx.dtype, + F_mode=problem_ctx.mode, + F_hdim=problem_ctx.hdim, + F_tile=kernel_ctx.tile, + F_pipeline=kernel_ctx.pipeline, + ) + + +class CompatibilityRuleFactory: + @staticmethod + def get_rules() -> 
list[CompatibilityRule]: + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + if problem_ctx.mode == "group": + if ( + kernel_ctx.pipeline.F_spad != "t" + or kernel_ctx.pipeline.F_skpad != "t" + ): + return False + return True + + def check_hdim(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if (problem_ctx.hdim, problem_ctx.hdim_v) == (192, 128): + if ( + kernel_ctx.pipeline.F_bias != "no" + or kernel_ctx.pipeline.F_dropout == "t" + ): + False + return True + + def check_feature( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + # logits_soft_cap is only allowed if no bias + if not ( + ( + kernel_ctx.pipeline.F_logits == "t" + and kernel_ctx.pipeline.F_bias == "no" + ) + or kernel_ctx.pipeline.F_logits == "f" + ): + return False + return True + + return [check_mode, check_hdim, check_feature] + + +class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): + _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"}) + + @classmethod + def get_rules(cls) -> list[CompatibilityRule]: + rules = CompatibilityRuleFactory.get_rules() + + def check_hdim_tile( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if problem_ctx.dtype != "fp32": + # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support + if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and ( + ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 != 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) + and kernel_ctx.tile.F_bm0 != 128 + ) + ): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + return False + return True + + rules.append(check_hdim_tile) + return rules + 
+ +class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9): + _AVAILABLE_PIPELINES = ( + CompatibilityRuleFactoryGfx9._AVAILABLE_PIPELINES + | frozenset({"qr_async_trload", "qr_async_trload_v3"}) + ) + + @classmethod + def get_rules(cls) -> list[CompatibilityRule]: + rules = CompatibilityRuleFactoryGfx9.get_rules() + + def check_tile_pipeline( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if kernel_ctx.pipeline.tag == "qr_async_trload" and ( + ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 == 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) not in [(64, 64), (128, 128)] + ) + ): + return False + + # only qr_async_trload_v3 use km0=256 & 8-warps + is_v3_dedicated_tile = ( + kernel_ctx.tile.F_bm0 == 256 + and (kernel_ctx.tile.F_rm0 * kernel_ctx.tile.F_rn0 * kernel_ctx.tile.F_rk0) == 8 + and (kernel_ctx.tile.F_rm1 * kernel_ctx.tile.F_rn1 * kernel_ctx.tile.F_rk1) == 8 + ) # fmt: skip + is_v3_pipeline = kernel_ctx.pipeline.tag == "qr_async_trload_v3" + return is_v3_dedicated_tile == is_v3_pipeline + + rules.extend([check_tile_pipeline]) + return rules + + +class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): arch = ArchTrait( "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" ) + _DT_FP32 = ("fp32",) + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8 = ("fp8",) + _DT_FP8BF16 = ("fp8bf16",) + _DT_FP8FP32 = ("fp8fp32",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return ( + cls._DT_FP32 + + cls._DT_FP16_BF16 + + cls._DT_FP8 + + cls._DT_FP8BF16 + + cls._DT_FP8FP32 + ) + # TODO: design a more practical way to do it # this is current supported tile size per hdim - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype in ["fp32"]: + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP32: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 
32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -666,7 +914,7 @@ class KernelComponentFactoryGfx9: (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp16", "bf16"]: + elif dtype in cls._DT_FP16_BF16: return { ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), @@ -682,30 +930,32 @@ class KernelComponentFactoryGfx9: (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip - elif dtype in ["fp8", "fp8bf16"]: + elif dtype in cls._DT_FP8 or dtype in cls._DT_FP8BF16: return { ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip - elif dtype in ["fp8fp32"]: + elif dtype in cls._DT_FP8FP32: return { (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip else: - return None + raise ValueError(f"unsupported dtype={dtype}") # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> 
List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let "t" padding to appear later!! # TODO: how to design this more generic? pipelines = [] - if dtype in ["fp32"]: + if dtype in cls._DT_FP32: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -718,7 +968,7 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - elif dtype in ["fp16", "bf16"]: + elif dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -743,7 +993,7 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip - elif dtype in ["fp8bf16", "fp8fp32"]: + elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], @@ -755,21 +1005,33 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO - None - else: - assert False + pass return pipelines -class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): +class 
KernelComponentFactoryGfx950( + KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950 +): arch = ArchTrait("gfx950") - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) + if dtype in cls._DT_FP16_BF16: + # add tile for qr_async_trload_v3 + if (128, 128) in result.keys(): + result[(128, 128)].append( + FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + return result + + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: pipelines = KernelComponentFactoryGfx9.get_pipelines( dtype, hdim, hdim_v, receipt, mask_impl ) - if dtype in ["fp16", "bf16"]: + if dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -788,15 +1050,31 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): ): pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip + + # qr_async_trload_v3 only supports hdim=hdim_v=128 for now + if (hdim, hdim_v) == (128, 128): + # qr_async_trload_v3 only supports (generic) causal mask + for mask in ["no", "causal"]: + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", + F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip + return pipelines -class KernelComponentFactoryGfx12: +class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if 
dtype in ["fp16", "bf16"]: + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8_FP8BF16 = ("fp8", "fp8bf16") + _DT_FP8FP32 = ("fp8fp32",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return cls._DT_FP16_BF16 + cls._DT_FP8_FP8BF16 + cls._DT_FP8FP32 + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP16_BF16: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -805,25 +1083,27 @@ class KernelComponentFactoryGfx12: (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp8", "fp8bf16"]: + elif dtype in cls._DT_FP8_FP8BF16: return { # bm0, bn0, bk0, bn1, bk1, ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp8fp32"]: + elif dtype in cls._DT_FP8FP32: return { # bm0, bn0, bk0, bn1, bk1, (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip else: - return None + raise ValueError(f"unsupported dtype={dtype}") - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: pipelines = [] - if dtype in ["fp16", "bf16"]: + if dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -835,23 +1115,21 @@ class KernelComponentFactoryGfx12: ): 
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: + elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - else: - assert False return pipelines -class CustomFactory(KernelComponentFactoryGfx9): - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: +class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9): + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) - if dtype == "fp16" or dtype == "bf16": + if dtype in cls._DT_FP16_BF16: if (128, 128) in result.keys(): result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result @@ -874,150 +1152,162 @@ def get_factory(target: str): raise Exception(f"Unsupported device target {target}") +@dataclass(frozen=True) +class Product: + name: str + rule: CompatibilityRule + + def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + return self.rule(problem_ctx, kernel_ctx) + + +def get_product(receipt: int) -> Product: + # Flash attention integration + if receipt in (2, 3): + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + 
cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= kernel_ctx.pipeline.F_skip == "f" + return cond + + return Product(name="Flash attention integration", rule=fit) + # PyTorch integration + elif receipt == 4: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "bias"] + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="PyTorch integration", rule=fit) + # Aiter(mha_fwd) integration + elif receipt == 100: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="Aiter(mha_fwd) integration", rule=fit) + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= problem_ctx.mode == "group" + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="Aiter(mha_varlen_fwd) integration", rule=fit) + # aiter::mha_fwd C++ api integration + elif receipt == 600: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == 
"row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="aiter::mha_fwd C++ api integration", rule=fit) + elif receipt == 888: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp8bf16", "fp8fp32"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="receipt = 888", rule=fit) + # fp32 only, all variations + elif receipt == 800: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype == "fp32" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="fp32 only, all variations", rule=fit) + # fp32 only, minimal set of parameters + elif receipt == 801: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype == "fp32" + cond &= problem_ctx.hdim in [48, 128] + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_bias == "no" + cond &= kernel_ctx.pipeline.F_lse == "f" + cond &= kernel_ctx.pipeline.F_dropout == "f" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + cond &= kernel_ctx.pipeline.F_mask == "s_no" + return cond + + return Product(name="fp32 only, minimal set of parameters", rule=fit) + # Don't build fp32 by default + else: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + return problem_ctx.dtype != "fp32" + + return Product(name="Default", rule=fit) + + def get_fwd_blobs( targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() - api_pool = FmhaFwdApiPool(mask_impl) + api_pool = FmhaFwdApiPool() factories = get_factories_for_targets(targets, get_factory) - for factory, dtype in 
itertools.product(factories, FWD_DTYPE_MAP.keys()): + for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()): d = factory.get_hdim_tile_size_dict(dtype) - if d is None: - continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product( d.items(), MODE_MAP.keys() ): + if optdim_list != [-1]: + if hdim not in optdim_list: + continue for tile, next_tile in zip(tiles, tiles[1:]): assert next_tile.F_bm0 >= tile.F_bm0, ( "Tiles must be ordered by increasing bm0" ) + for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if mode == "group": - if pipeline.F_spad != "t" or pipeline.F_skpad != "t": - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - if (hdim, hdim_v) == (192, 128): - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != "no" or pipeline.F_dropout == "t": - continue - if factory.arch.name.startswith("gfx9") and dtype != "fp32": - # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support - if pipeline.tag != "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) - or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128) - ): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) - or ((hdim, hdim_v) not in [(64, 64), (128, 128)]) - ): - continue - # logits_soft_cap is only allowed if no bias - if not ( - (pipeline.F_logits == "t" and pipeline.F_bias == "no") - or pipeline.F_logits == "f" - ): - continue - k = FmhaFwdKernel( - F_arch=factory.arch, - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, 
- F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, + problem_ctx = ProblemContext( + dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v ) + kernel_ctx = KernelContext( + tile=tile, pipeline=pipeline, mask_impl=mask_impl + ) + rules = factory.get_rules() + product = get_product(receipt) + + if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]): + continue + + k = create_kernel(factory.arch, problem_ctx, kernel_ctx) if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_qscale == "no" - cond &= pipeline.F_skip == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_qscale == "no" - cond &= mode == "batch" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - elif receipt == 888: - cond = dtype in ["fp8bf16", "fp8fp32"] - cond &= 
pipeline.F_vlayout == "row" - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - - # fp32 only, all variations - if receipt == 800: - cond = dtype == "fp32" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # fp32 only, minimal set of parameters - elif receipt == 801: - cond = dtype == "fp32" - cond &= hdim in [48, 128] - cond &= mode == "batch" - cond &= pipeline.F_bias == "no" - cond &= pipeline.F_lse == "f" - cond &= pipeline.F_dropout == "f" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - cond &= pipeline.F_mask == "s_no" - if not cond: - continue - else: - # Don't build fp32 by default - if dtype == "fp32": - continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -1026,11 +1316,34 @@ def get_fwd_blobs( def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) + update_file(autogen_dir / kernel.filename, kernel.render()) -def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) +def write_fwd_api( + api_pool: FmhaFwdApiPool, + autogen_dir: Path, +) -> None: + def accept_only_v3(trait: FmhaFwdApiTrait) -> bool: + return trait.pipeline_tag == "qr_async_trload_v3" + + def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: + return not accept_only_v3(trait) + + content = "".join( + [ + FMHA_FWD_API_HEADER, + api_pool.render("fmha_fwd_v2", filter_fn=accept_only_v2), + api_pool.render("fmha_fwd_v3", filter_fn=accept_only_v3), + FMHA_FWD_API_FOOTER_TEMPLATE.format( + F_is_v3_enabled=BOOL_MAP[ + # NOTE: enable v3 pipelines when ready + # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + False + ] + ), + ] + ) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, content) def write_blobs( diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp deleted file mode 
100644 index c510b36bb5..0000000000 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ /dev/null @@ -1,616 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "fmha_fwd.hpp" -#include "fmha_fwd_v3.hpp" -#include "mask.hpp" - -auto parse_cmd_args(int argc, char* argv[]) -> std::pair -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("prec", "fp16", "data type. fp16/bf16") - .insert("b", "2", "batch size") - .insert("h", "8", "num of head, for q") - .insert("h_k", - "-1", - "num of head, for k/v, -1 means equal to h\n" - "if not equal to h, then this is GQA/MQA case") - .insert("s", "3328", "seqlen_q") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") - .insert("d", "128", "head dim for q & k") - .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") - .insert("iperm", - "0", - "permute input\n" - "if true, will be b*h*s*d, else b*s*h*d") - .insert("operm", "0", "permute output") - .insert("causal", "0", "0: no mask, 1: causal mask") - .insert("v", "1", "0:no verify, 1:verify") - .insert("seed", - "11939", - "random seed used for initializing input tensors. 0 for " - "non-deterministic seed") - .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "30", "number of iterations to benchmark the kernel") - // Optional effective seqlen override (exclude PAD) for batch mode - .insert("q_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override.") - .insert("kv_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" - "Comma-separated list of length 'b'. 
If empty, no override."); - - bool result = arg_parser.parse(argc, argv); - return std::make_pair(result, arg_parser); -} - -enum class TensorLayout -{ - bhsd, - bshd, -}; - -std::ostream& operator<<(std::ostream& stream, TensorLayout layout) -{ - switch(layout) - { - case TensorLayout::bhsd: return stream << "bhsd"; - case TensorLayout::bshd: return stream << "bshd"; - default: return stream << "unknown"; - } -} - -struct Problem -{ - explicit Problem(const ck_tile::ArgParser& args) - { - data_type = args.get_str("prec") == "fp16" - ? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16 - : ck_tile::fmha_fwd_v3_args::data_type_enum::bf16; - batch = args.get_int("b"); - seqlen_q = args.get_int("s"); - seqlen_k = args.get_int("s_k"); - if(seqlen_k < 0) - { - seqlen_k = seqlen_q; - } - nhead_q = args.get_int("h"); - nhead_kv = args.get_int("h_k"); - if(nhead_kv < 0) - { - nhead_kv = nhead_q; - } - hdim = args.get_int("d"); - softmax_scale = args.get_float("scale_s"); - if(softmax_scale == .0f) - softmax_scale = 1.0 / ck_tile::sqrt(static_cast(hdim)); - - const auto is_causal = args.get_bool("causal"); - if(is_causal) - { - mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); - } - else - { - mask = mask_info::decode("0", seqlen_q, seqlen_k); - } - - input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; - output_layout = args.get_int("operm") == 1 ? 
TensorLayout::bhsd : TensorLayout::bshd; - q_eff_lens = args.get_int_vec("q_eff_lens"); - kv_eff_lens = args.get_int_vec("kv_eff_lens"); - } - - std::vector get_query_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } - } - - std::vector get_key_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } - } - - std::vector get_value_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } - } - - std::vector get_output_shape() const - { - if(output_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } - } - - ck_tile::fmha_fwd_v3_args::data_type_enum data_type; - ck_tile::index_t batch; - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_kv; - ck_tile::index_t hdim; - float softmax_scale; - mask_info mask; - TensorLayout input_layout; - TensorLayout output_layout; - std::vector q_eff_lens; - std::vector kv_eff_lens; -}; - -struct RunConfig -{ - explicit RunConfig(const ck_tile::ArgParser& args) - { - seed = args.get_uint32("seed"); - if(*seed == 0) - { - seed.reset(); - } - - kernel_warmup = args.get_int("warmup"); - kernel_repeat = args.get_int("repeat"); - verify = args.get_bool("v"); - } - - std::optional seed; - int kernel_warmup; - int kernel_repeat; - bool verify; -}; - -template -auto generate_qkv(const Problem& problem, - [[maybe_unused]] std::optional seed = std::nullopt) - -> std::tuple, - ck_tile::HostTensor, - ck_tile::HostTensor> -{ - ck_tile::HostTensor q(problem.get_query_shape()); - ck_tile::HostTensor k(problem.get_key_shape()); - ck_tile::HostTensor 
v(problem.get_value_shape()); - - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v); - - return std::make_tuple(q, k, v); -} - -namespace host { -template -CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, - const ck_tile::HostTensor& k_bshd, - const ck_tile::HostTensor& v_bshd, - const mask_info& mask, - ck_tile::HostTensor& o_bshd, - const QElementOp& q_element_op = {}, - const KElementOp& k_element_op = {}, - const VElementOp& v_element_op = {}, - const SAccElementOp& s_acc_element_op = {}) -{ - const int batch_size = q_bshd.mDesc.get_lengths()[0]; - const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; - const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; - const int nhead_q = q_bshd.mDesc.get_lengths()[2]; - const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; - const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; - const int hdim_v = v_bshd.mDesc.get_lengths()[3]; - - const int nr = nhead_q / nhead_kv; - - ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); - ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); - ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); - ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); - - ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); - ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - - // do computation for each batch - for(int b = 0; b < batch_size; ++b) - { - // copy per-batch data from input tensors - // clang-format off - q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); - k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); - v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); - // clang-format on - ck_tile::reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, q_element_op, 
k_element_op, s_acc_element_op); - - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, seqlen_q, seqlen_kv)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - else - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - } - - ck_tile::reference_batched_softmax( - s_host_ref, p_host_ref, ck_tile::identity{}); - - ck_tile::reference_batched_gemm( - p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - - // copy resulting per-batch data to the output tensor - o_host_ref.ForEach( - [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); - } -} -} // namespace host - -template -bool run_impl(const Problem& problem, const RunConfig& run_config) -{ - auto [q, k, v] = generate_qkv(problem, run_config.seed); - - ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes()); - /// FIXME: use correct size for output tensor. 
just use q size for now since hidm_qk = hdim_v - ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes()); - - q_buf.ToDevice(q.data()); - k_buf.ToDevice(k.data()); - v_buf.ToDevice(v.data()); - // Ensure output buffer is zero-initialized so padded regions compare cleanly - o_buf.SetZero(); - - ck_tile::fmha_fwd_v3_args args{}; - - args.data_type = problem.data_type; - args.batch = problem.batch; - args.seqlen_q = problem.seqlen_q; - args.seqlen_k = problem.seqlen_k; - args.nhead_q = problem.nhead_q; - args.nhead_kv = problem.nhead_kv; - args.hdim_qk = problem.hdim; - args.hdim_v = problem.hdim; - args.softmax_scale = problem.softmax_scale; - - args.window_size_left = problem.mask.left; - args.window_size_right = problem.mask.right; - args.mask_type = static_cast(problem.mask.type); - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) - args.q_ptr = q_buf.GetDeviceBuffer(); - args.stride_q = - problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_q = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim; - args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) - args.k_ptr = k_buf.GetDeviceBuffer(); - args.stride_k = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_k = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) - args.v_ptr = v_buf.GetDeviceBuffer(); - args.stride_v = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_v = - problem.input_layout == TensorLayout::bshd ? 
problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) - args.o_ptr = o_buf.GetDeviceBuffer(); - args.stride_o = - problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_o = problem.output_layout == TensorLayout::bshd - ? problem.hdim - : problem.seqlen_q * problem.hdim; - args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // Optional cumulative seqlen overrides (exclude PAD) - const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; - const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; - - auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { - std::vector eff; - if(!opt_vec.empty() && opt_vec[0] != -1) - { - eff.assign(opt_vec.begin(), opt_vec.end()); - if(eff.size() < static_cast(problem.batch)) - { - eff.resize(problem.batch, eff.back()); - } - } - else - { - eff.assign(problem.batch, fallback); - } - return eff; - }; - - const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); - const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); - - // Calculate cumulative sums for kernel arguments if varlen is used - std::vector cuq_cum, cukv_cum; - auto calculate_cumulative = [&](const std::vector& per_batch_vec, - std::vector& cum_vec) { - cum_vec.resize(per_batch_vec.size() + 1); - cum_vec[0] = 0; - for(std::size_t i = 0; i < per_batch_vec.size(); ++i) - cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; - }; - - if(has_varlen_q) - { - calculate_cumulative(eff_q_vec, cuq_cum); - } - if(has_varlen_k) - { - calculate_cumulative(eff_kv_vec, cukv_cum); - } - - ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? 
cuq_cum.size() * sizeof(ck_tile::index_t) : 0); - ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); - cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); - cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); - args.cu_seqlen_q_ptr = - !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) - : nullptr; - args.cu_seqlen_kv_ptr = - !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) - : nullptr; - - ck_tile::stream_config stream_config{nullptr, - true, - /*log_level=*/0, - run_config.kernel_warmup, - run_config.kernel_repeat}; - - auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config); - if(!result) - { - std::cerr << "faild to run fmha_fwd_v3()" << std::endl; - return false; - } - - std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) - { - return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - else - { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. 
- return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - }(); - float tflops = static_cast(flop) / 1.e9 / time; - - std::cout << "[" << problem.data_type << "|"; - if(problem.input_layout == problem.output_layout) - { - std::cout << problem.input_layout; - } - else - { - std::cout << problem.input_layout << "-" << problem.output_layout; - } - std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim - << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed - << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops - << " TFlops" << std::endl; - - if(!run_config.verify) - { - return true; - } - - // transpose tensor descriptors from bhsd to bshd if necessary - if(problem.input_layout != TensorLayout::bshd) - { - q = q.transpose({0, 2, 1, 3}); - k = k.transpose({0, 2, 1, 3}); - v = v.transpose({0, 2, 1, 3}); - } - - ck_tile::HostTensor o_ref(problem.get_output_shape()); - if(problem.output_layout != TensorLayout::bshd) - { - o_ref = o_ref.transpose({0, 2, 1, 3}); - } - - // If variable lengths are provided, compute per-batch references - // with the effective lengths; else compute a single full reference. - if(has_varlen_q || has_varlen_k) - { - // Variable-length aware verification: zero-fill padded region and only compute valid part. 
- o_ref.SetZero(); - - for(int b = 0; b < problem.batch; ++b) - { - const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - - if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - continue; - - // Slice current batch from inputs (bshd) and build single-batch tensors - ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - - // Copy effective region - q_b.ForEach([&](auto& self, auto idx) { - // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); - }); - k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - host::fmha_fwd(q_b, - k_b, - v_b, - problem.mask, - o_b, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - - // Scatter into o_ref's bshd descriptor memory - for(int s = 0; s < seqlen_q_eff; ++s) - { - for(int h = 0; h < problem.nhead_q; ++h) - { - for(int d = 0; d < problem.hdim; ++d) - { - o_ref(b, s, h, d) = o_b(0, s, h, d); - } - } - } - } - } - else - { - // No varlen override: compute the full reference once - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - } - - ck_tile::HostTensor o(problem.get_output_shape()); - o_buf.FromDevice(o.data()); - - const auto [rtol, atol] = [&] { - if constexpr(std::is_same_v) - return std::make_tuple(1e-3, 1e-3); - else - return std::make_tuple(1e-2, 1e-2); - }(); - return ck_tile::check_err(o, o_ref, std::string("found incorrect 
results!"), rtol, atol); -} - -int main(int argc, char* argv[]) -{ - auto [parse_result, args] = parse_cmd_args(argc, argv); - if(!parse_result) - { - std::cerr << "failed to parse command line arguments" << std::endl; - } - - Problem problem(args); - RunConfig run_config(args); - - const auto run = [&] { - if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16) - { - return run_impl(problem, run_config); - } - else - { - return run_impl(problem, run_config); - } - }; - - return !run(); -} diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index f279ebfcea..002d0a1035 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -686,6 +686,100 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } } +template +auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) +{ + /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly + /// maximizes the kernel's performance. 
+ int remap_opt = 2; + if(args.mask_type != static_cast(mask_enum::no_mask) && + ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) + { + if(65536 <= args.seqlen_q) + { + remap_opt = 0; + } + else + { + remap_opt = 1; + } + } + + auto kargs = [&] { + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_q_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + else + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + 0, // batch_stride_lse + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + template auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) { diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp deleted file mode 100644 index 1c0256cc0f..0000000000 --- 
a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" -#include "mask.hpp" - -namespace ck_tile { - -std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type) -{ - switch(data_type) - { - case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16"; - case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16"; - default: return stream << "unknown"; - } -} - -std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config) -{ - if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - else - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - } - else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - else - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - } - - return std::make_pair(false, -1.f); -} - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp deleted file mode 100644 index 54cc4960a5..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/core/numeric/integer.hpp" -#include "ck_tile/host/stream_config.hpp" - -namespace ck_tile { - -struct fmha_fwd_v3_args -{ - enum class data_type_enum - { - fp16, - bf16 - }; - - data_type_enum data_type; - // bool is_varlen; - - index_t batch; - index_t seqlen_q; - index_t seqlen_k; - index_t nhead_q; - index_t nhead_kv; - index_t hdim_qk; - index_t hdim_v; - - float softmax_scale; - - index_t window_size_left; - index_t window_size_right; - index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and - // window_size_right == 0). - - const void* q_ptr; - index_t stride_q; - index_t nhead_stride_q; - index_t batch_stride_q; - - const void* k_ptr; - index_t stride_k; - index_t nhead_stride_k; - index_t batch_stride_k; - - const void* v_ptr; - index_t stride_v; - index_t nhead_stride_v; - index_t batch_stride_v; - - void* o_ptr; - index_t stride_o; - index_t nhead_stride_o; - index_t batch_stride_o; - - // Optional batch-mode cumulative seqlen overrides (exclude PAD) - // If provided, they override per-batch effective lengths to skip tail padding. 
- const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] -}; - -std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); - -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config); - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp deleted file mode 100644 index 19b8dfed4e..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include - -#include "ck_tile/core/numeric/bfloat16.hpp" -#include "ck_tile/core/numeric/half.hpp" -#include "ck_tile/core/container/sequence.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" -#include "ck_tile/ops/fmha/block/block_masking.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" - -#include "fmha_fwd_v3.hpp" -#include "mask.hpp" - -#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ - template <> \ - std::pair fmha_fwd_v3_kernel_dispatch( \ - const fmha_fwd_v3_args& args, const stream_config& config) \ - { \ - return std::make_pair(true, \ - fmha_fwd_v3_kernel_launch(args, config)); \ - } - -namespace ck_tile { - -template -struct fmha_fwd_v3_problem_traits; - -template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::half_t; - using acc_dtype 
= float; - using o_dtype = ck_tile::half_t; - using lse_dtype = float; -}; - -template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::bf16_t; - using acc_dtype = float; - using o_dtype = ck_tile::bf16_t; - using lse_dtype = float; -}; - -template -struct fmha_fwd_v3_kernel_traits -{ - static constexpr auto date_type = DataType; - static constexpr bool is_variable_seqlen = IsVariableSeqlen; - static constexpr bool is_masking = IsMasking; - - // M0 N0 K0 N1 K1 - using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>; - using fmha_warp_gemm_shape = sequence<32, 32, 16>; - using fmha_block_warps = sequence<8, 1, 1>; - - using fmha_shape = TileFmhaShape; - - using fmha_traits = TileFmhaFwdV3Traits; - - using fmha_mask = GenericAttentionMask; - - using fmha_pipeline_problem = - BlockFmhaFwdV3PipelineProblem::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::lse_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, - fmha_shape, - IsVariableSeqlen, - fmha_mask, - fmha_traits>; - - using fmha_pipeline = BlockFmhaFwdV3Pipeline; - - using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, - true, // kPadM - true, // kPadM - true // UseRawStore - >>; - - using kernel = FmhaFwdV3Kernel; -}; - -template -float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) -{ - /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly - /// maximizes the kernel's performance. 
- int remap_opt = 2; - if(args.mask_type != static_cast(mask_enum::no_mask) && - ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) - { - if(65536 <= args.seqlen_q) - { - remap_opt = 0; - } - else - { - remap_opt = 1; - } - } - - auto kargs = Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - nullptr, // lse_ptr - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_qk, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_kv, - args.softmax_scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - 0, // nhead_stride_lse - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - 0, // batch_stride_lse - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - remap_opt, - args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); - - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); - constexpr dim3 blocks = Kernel::BlockSize(); - constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; - - return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); -} - -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -template -std::pair fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args, - const stream_config& config); - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp deleted file mode 100644 index 463c52b824..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp deleted file mode 100644 index acf79e43f4..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp deleted file mode 100644 index a6366209b2..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp deleted file mode 100644 index a83e37cc68..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 81eea60c2f..29a7e2593e 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1552,6 +1552,81 @@ CK_TILE_HOST_DEVICE static void print(const indexing& printf("}"); } +template +struct functor_transform : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + Functor functor_; + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr functor_transform() = default; + + CK_TILE_HOST_DEVICE constexpr functor_transform(const Functor& functor, + const LowLength& low_length) + : functor_{functor}, up_lengths_{make_tuple(low_length)} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = functor_(idx_up[number<0>{}]); + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& up_idx) const + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! 
inconsistent # of dimension"); + + const auto idx_low_old = idx_low; + calculate_lower_index(idx_low, up_idx); + idx_diff_low = idx_low - idx_low_old; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + // Note: When using functor_transform, ensure that the transformed coordinates + // are always valid for vectorized load/store operations. + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + return make_tuple(low_vector_lengths, low_vector_strides); + } +}; + //******************************************************************************************************* template @@ -1671,6 +1746,13 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le return offset{low_length, offset_length}; } +template +CK_TILE_HOST_DEVICE constexpr auto make_functor_transform(const Functor& functor, + const LowLength& low_length) +{ + return functor_transform{functor, low_length}; +} + } // namespace ck_tile #include "ck_tile/core/algorithm/indexing_adaptor.hpp" diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index e80267faec..d39da82a62 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1263,7 +1263,9 @@ struct tile_window_with_static_lengths } }; -template +template >> CK_TILE_DEVICE constexpr auto make_tile_window(const TensorView_& tensor_view, const WindowLengths_& window_lengths, @@ -1310,7 +1312,10 @@ 
make_tile_window(const tile_window_with_static_lengths +template >> CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution, diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 1a79aebef5..756968871d 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -600,6 +600,19 @@ struct SimplifiedRatioAttentionMask mdiv y_ratio_mdiv; }; +template +struct is_generic_attention_mask : std::false_type +{ +}; + +template +struct is_generic_attention_mask> : std::true_type +{ +}; + +template +static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask::value; + // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 38830ee6fe..9890d1f2e4 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -73,54 +73,6 @@ struct FmhaFwdKernel #endif static constexpr std::string_view kPipelineName = FmhaPipeline::name; - // clang-format off - template struct t2s; - template <> struct t2s { static constexpr const char * name = "fp32"; }; - template <> struct t2s { static constexpr const char * name = "fp16"; }; - template <> struct t2s { static constexpr const char * name = "bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8"; }; - template <> struct t2s { static constexpr const char * name = "bf8"; }; - template <> struct t2s { static constexpr const char * name = "fp8bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8fp32"; }; - // clang-format on - - CK_TILE_HOST static std::string 
GetName() - { - // sync with generate.py - // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using g0br = typename bfs::Gemm0BlockWarps; - using g1br = typename bfs::Gemm1BlockWarps; - using g0wt = typename bfs::Gemm0WarpTile; - using g1wt = typename bfs::Gemm1WarpTile; - #define _SS_ std::string - #define _TS_ std::to_string - auto pn = [&] () { - std::string n; - if (kPadSeqLenQ) n += "s"; - if (kPadSeqLenK) n += "sk"; - if (kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; - return n.empty() ? n : std::string("p") + n; }(); - return - _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + - "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + - "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + - (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + - (QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? 
_SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr::name)) + (kUseTrLoad ? "_trload" : "_ntrload"); - #undef _SS_ - #undef _TS_ - // clang-format on - } - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index df17bdd879..f981c54bd8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -12,6 +12,8 @@ namespace ck_tile { +/// NOTICE: This kernel is a work in progress and is awaiting upcoming compiler fixes and +/// instruction scheduling optimizations. template struct FmhaFwdV3Kernel { @@ -103,8 +105,8 @@ struct FmhaFwdV3Kernel // Optional cumulative sequence length pointers for batch mode // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_k_ptr = nullptr; // [batch+1] }; struct FmhaFwdGroupModeKargs @@ -114,12 +116,13 @@ struct FmhaFwdV3Kernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; const int32_t* seqlen_k_ptr; // Optional cumulative padded sequence starts (including PAD tokens) // Used solely to compute memory offsets when sequences are physically padded. 
- const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1] - const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1] + const int32_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const int32_t* cu_seqlen_k_ptr = nullptr; // [batch+1] }; using Kargs = std::conditional_t; @@ -156,8 +159,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -199,8 +202,8 @@ struct FmhaFwdV3Kernel kargs.batch_stride_lse = batch_stride_lse; } - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } @@ -213,6 +216,7 @@ struct FmhaFwdV3Kernel void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -232,8 +236,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -258,6 +262,7 @@ struct FmhaFwdV3Kernel {}, // placeholder for lse reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; if constexpr(kHasMask) @@ -273,30 +278,29 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; } - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); + kargs.cu_seqlen_q_ptr = 
reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v) { - // TODO: this may need tuning - if constexpr(kHasMask) + if constexpr(kIsGroupMode) { - return dim3(nhead_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - batch_size_); + return dim3(nhead, + batch_size, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1)); } else { - return dim3(nhead_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - batch_size_); + return dim3(nhead, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1), + batch_size); } } @@ -344,13 +348,20 @@ struct FmhaFwdV3Kernel // FmhaPipeline::kN1); // assume that num_tile_n1 is always 1 - if constexpr(kHasMask) + if constexpr(kIsGroupMode) { const index_t i_nhead = blockIdx.x; - const index_t i_block = blockIdx.y; - const index_t i_batch = blockIdx.z; + const index_t i_batch = blockIdx.y; + const index_t i_block = blockIdx.z; - return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + if constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.z - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + } } else { @@ -358,7 +369,14 @@ struct FmhaFwdV3Kernel const index_t i_block = blockIdx.y; const index_t i_batch = blockIdx.z; - return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + if 
constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + } } } @@ -390,32 +408,36 @@ struct FmhaFwdV3Kernel if constexpr(kIsGroupMode) { - // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + // Use seqstart_q_ptr and seqstart_k_ptr for physical starts + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; if constexpr(kStoreLSE) { // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; + batch_offset_lse = query_start; } - batch_offset_o = query_start_padded * kargs.stride_o; - - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + batch_offset_o = query_start * kargs.stride_o; + // real logical lengths (exclude PAD) + // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr + if(kargs.seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch]; + } + else if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch 
+ 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + } // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0) @@ -427,10 +449,14 @@ struct FmhaFwdV3Kernel { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } + else if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } else { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; } } else @@ -450,10 +476,10 @@ struct FmhaFwdV3Kernel kargs.seqlen_q = kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; } - if(kargs.cu_seqlen_kv_ptr != nullptr) + if(kargs.cu_seqlen_k_ptr != nullptr) { kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 8bf24be386..68ec349694 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -246,6 +248,8 @@ CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) } } // namespace detail +/// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and +/// instruction 
scheduling optimizations. template struct BlockFmhaFwdV3Pipeline { @@ -261,12 +265,16 @@ struct BlockFmhaFwdV3Pipeline using OaccDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; + static_assert(is_generic_attention_mask_v); static_assert(std::is_same_v, "we will the same dist tensor 'sp_compute' for both gemm0 & softmax"); using BlockFmhaShape = ck_tile::remove_cvref_t; + using VLayout = remove_cvref_t; + static_assert(std::is_same_v); + static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; @@ -277,14 +285,24 @@ struct BlockFmhaFwdV3Pipeline static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto QScaleEnum = 
Problem::QScaleEnum; + static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; + static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && + !kStoreLSE && !kHasDropout && + (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && + !kSkipMinSeqlenQ), + "enable unsupported features"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index da0fa16ee1..659bdd995b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum QRKSVS_ASYNC, QSKSVS, QRKSVS_ASYNC_TRLOAD, + QRKSVS_ASYNC_TRLOAD_V3, }; template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index b90b760a0d..7c4a921b70 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -264,47 +264,4 @@ struct BlockFmhaFwdAppendKVPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; -template -struct BlockFmhaFwdV3PipelineProblem -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using FmhaMask = remove_cvref_t; - using Traits = remove_cvref_t; - - static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; - 
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; - static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); - - static constexpr bool kIsGroupMode = kIsGroupMode_; - - // attributes from traits - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kStoreLSE = Traits::kStoreLSE; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; -}; - } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index b9e18de1e5..df33a93696 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -166,20 +166,4 @@ struct TileFmhaBwdConvertQGradTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; -template -struct TileFmhaFwdV3Traits -{ - static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadSeqLenK = kPadSeqLenK_; - static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; - static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kStoreLSE = kStoreLSE_; - static constexpr index_t kBlockPerCu = kBlockPerCu_; -}; - } // namespace ck_tile diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index affa6d987b..aeec7bd471 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -90,7 +90,7 @@ submodule = submodule_t() # formatting format_procs = [] for x in all_files: - dos2unix = f"python -m dos2unix {str(x)} {str(x)}" + dos2unix = f"python3 -m dos2unix {str(x)} {str(x)}" clang_format = f"clang-format -style=file -i {str(x)}" # One process to avoid race conditions. 
cmd = f"{dos2unix} && {clang_format}" From 13f6d635653bd5ffbfcac8577f1ef09590c23d78 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Thu, 4 Dec 2025 19:12:36 -0800 Subject: [PATCH 10/65] Clean up conv_traits.hpp (#3354) When I asked for a description of operators that didn't have ConvTraits, I was getting very long confusing errors about ConvTraits not being defined. Now we get specific errors explaining which concepts are violated, making it easier to know which code to generalize or update. * Add concepts to conv_traits.hpp to get better error message. * Put the correct requires clauses in the right places to get descriptive error messages. * General cleanup of functions in conv_traits.hpp to make functions easier to read. --- .../builder/reflect/conv_description.hpp | 8 +- .../ck_tile/builder/reflect/conv_traits.hpp | 457 +++++++----------- 2 files changed, 186 insertions(+), 279 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 261c3f103d..59ff83c238 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -251,14 +251,10 @@ class ConvDescription : public Description }; } // namespace conv -/// @brief Helper concept to detect if a type has ConvTraits specialization -template -concept HasConvTraits = requires { typename conv::ConvTraits; }; - /// @brief Factory function to create ConvDescription from a convolution instance type -/// @tparam Instance The convolution instance type (must have InstanceTraits specialization) +/// @tparam Instance The convolution instance type (must have ConvTraits specialization) /// @return A ConvDescription object populated with the instance's configuration details -template +template conv::ConvDescription describe() { using Traits = conv::ConvTraits; diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 29ac49e549..918fd6bdb6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -21,6 +21,57 @@ namespace ck_tile::reflect::conv { +// Forward convolution layout concept - checks for A/B/E layout types +template +concept HasFwdConvLayouts = requires { + typename T::ALayout; + typename T::BLayout; + typename T::ELayout; +}; + +// GEMM specialization concept - checks for kGemmSpecialization member +template +concept HasGemmSpec = requires { + { + T::kGemmSpecialization + } -> std::convertible_to; +}; + +// Data types concept - checks for ADataType member +template +concept HasDataTypes = requires { typename T::ADataType; }; + +// Elementwise operations concept - checks for A/B/CDE elementwise operation types +template +concept HasElementwiseOps = requires { + typename T::AElementwiseOperation; + typename T::BElementwiseOperation; + typename T::CDEElementwiseOperation; +}; + +// Tile parameters concept - checks for tile dimension and transfer members +template +concept HasTileParams = requires { + { T::kKPerBlock } -> std::convertible_to; + { T::kMPerBlock } -> std::convertible_to; + { T::kNPerBlock } -> std::convertible_to; + { T::kAK1 } -> std::convertible_to; + { T::kBK1 } -> std::convertible_to; + T::kCThreadClusterLengths; +}; + +// Comprehensive concept that checks if an instance has all XDL forward convolution traits +// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions +template +concept IsXdlFwdConv = HasFwdConvLayouts && HasGemmSpec && HasDataTypes && + HasElementwiseOps && HasTileParams; + +// Primary concept for checking if a type can be described +// Currently only forward convolutions are supported, but this can be extended +// in the future to include backward 
data and backward weight convolutions +template +concept HasConvTraits = IsXdlFwdConv>; + // Helper metafunctions to convert from ck enums to builder enums /// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. @@ -35,16 +86,15 @@ constexpr auto convert_pipeline_version() { using enum ck::BlockGemmPipelineVersion; using enum builder::PipelineVersion; - if constexpr(ck_ver == v1) - return V1; - else if constexpr(ck_ver == v2) - return V2; - else if constexpr(ck_ver == v3) - return V3; - else if constexpr(ck_ver == v4) - return V4; - else if constexpr(ck_ver == v5) - return V5; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v3: return V3; + case v4: return V4; + case v5: return V5; + } } /// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. @@ -59,14 +109,14 @@ constexpr auto convert_pipeline_version() { using enum ck::PipelineVersion; using enum builder::PipelineVersion; - if constexpr(ck_ver == v1) - return V1; - else if constexpr(ck_ver == v2) - return V2; - else if constexpr(ck_ver == v4) - return V4; - else if constexpr(ck_ver == weight_only) - return WEIGHT_ONLY; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v4: return V4; + case weight_only: return WEIGHT_ONLY; + } } /// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. @@ -82,10 +132,12 @@ constexpr auto convert_pipeline_scheduler() { using enum ck::BlockGemmPipelineScheduler; using enum builder::PipelineScheduler; - if constexpr(ck_sched == Intrawave) - return INTRAWAVE; - else if constexpr(ck_sched == Interwave) - return INTERWAVE; + + switch(ck_sched) + { + case Intrawave: return INTRAWAVE; + case Interwave: return INTERWAVE; + } } /// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. 
@@ -101,10 +153,12 @@ constexpr auto convert_pipeline_scheduler() { using enum ck::LoopScheduler; using enum builder::PipelineScheduler; - if constexpr(ck_sched == Default) - return DEFAULT; - else if constexpr(ck_sched == Interwave) - return INTERWAVE; + + switch(ck_sched) + { + case Default: return DEFAULT; + case Interwave: return INTERWAVE; + } } /// @brief Helper structures for organizing trait data with domain-specific naming @@ -213,21 +267,13 @@ constexpr builder::ConvDirection conv_direction() using InstTraits = InstanceTraits; if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) - { return builder::ConvDirection::FORWARD; - } else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) - { return builder::ConvDirection::BACKWARD_DATA; - } else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) - { return builder::ConvDirection::BACKWARD_WEIGHT; - } else - { return builder::ConvDirection::FORWARD; // Default fallback - } } /// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. 
@@ -242,60 +288,52 @@ constexpr auto conv_spec() if constexpr(requires { InstTraits::kConvForwardSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + using enum builder::ConvFwdSpecialization; - if constexpr(InstTraits::kConvForwardSpecialization == Default) + switch(InstTraits::kConvForwardSpecialization) { - return builder::ConvFwdSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0) - { - return builder::ConvFwdSpecialization::FILTER_1X1_PAD0; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3) - { - return builder::ConvFwdSpecialization::FILTER_3x3; + case Default: return DEFAULT; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter3x3: return FILTER_3x3; } } else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + using enum builder::ConvBwdDataSpecialization; - if constexpr(InstTraits::kConvBwdDataSpecialization == Default) + switch(InstTraits::kConvBwdDataSpecialization) { - return builder::ConvBwdDataSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0; + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; } } else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + using enum builder::ConvBwdWeightSpecialization; - if constexpr(InstTraits::kConvBwdWeightSpecialization == Default) + switch(InstTraits::kConvBwdWeightSpecialization) { - return 
builder::ConvBwdWeightSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0) - { - return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC) - { - return builder::ConvBwdWeightSpecialization::ODD_C; + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case OddC: return ODD_C; } } } +// Helper variable template to check if CK layout enums match +template +inline constexpr bool layouts_are = + std::is_same_v && std::is_same_v && std::is_same_v; + /// @brief Derives the grouped convolution layout from a device kernel `Instance` type. /// @tparam Instance The device kernel instance type. /// @return An std::array corresponding to the tensor layouts: @@ -304,112 +342,49 @@ constexpr auto conv_spec() /// index 2 -> Output layout template constexpr auto conv_layout() + requires HasFwdConvLayouts> { - using InstTraits = InstanceTraits; - using ALayout = typename InstTraits::ALayout; - using BLayout = typename InstTraits::BLayout; - using ELayout = typename InstTraits::ELayout; + // Helper lambda to construct layout array + auto layouts = [](auto... 
Ls) { return std::array{Ls...}; }; - namespace ctc = ck::tensor_layout::convolution; + using A = typename InstanceTraits::ALayout; + using B = typename InstanceTraits::BLayout; + using E = typename InstanceTraits::ELayout; + namespace ctl = ck::tensor_layout::convolution; + using enum builder::TensorLayout; - if constexpr(InstTraits::kSpatialDim == 1) + switch(InstanceTraits::kSpatialDim) { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNWC, - builder::TensorLayout::GKXC, - builder::TensorLayout::GNWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NWGC, - builder::TensorLayout::GKXC, - builder::TensorLayout::NWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NGCW, - builder::TensorLayout::GKXC, - builder::TensorLayout::NGKW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NGCW, - builder::TensorLayout::GKCX, - builder::TensorLayout::NGKW}; - } - } - else if constexpr(InstTraits::kSpatialDim == 2) - { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNHWC, - builder::TensorLayout::GKYXC, - builder::TensorLayout::GNHWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NHWGC, - builder::TensorLayout::GKYXC, - builder::TensorLayout::NHWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCHW, - builder::TensorLayout::GKYXC, - builder::TensorLayout::NGKHW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCHW, - builder::TensorLayout::GKCYX, - 
builder::TensorLayout::NGKHW}; - } - } - else if constexpr(InstTraits::kSpatialDim == 3) - { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNDHWC, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::GNDHWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NDHWGC, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::NDHWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCDHW, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::NGKDHW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCDHW, - builder::TensorLayout::GKCZYX, - builder::TensorLayout::NGKDHW}; - } + case 1: + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(NWGC, GKXC, NWGK); + if constexpr(layouts_are) + return layouts(NGCW, GKXC, NGKW); + if constexpr(layouts_are) + return layouts(NGCW, GKCX, NGKW); + break; + case 2: + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + return layouts(NGCHW, GKYXC, NGKHW); + if constexpr(layouts_are) + return layouts(NGCHW, GKCYX, NGKHW); + break; + case 3: + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(NDHWGC, GKZYXC, NDHWGK); + if constexpr(layouts_are) + return layouts(NGCDHW, GKZYXC, NGKDHW); + if constexpr(layouts_are) + return layouts(NGCDHW, GKCZYX, NGKDHW); + break; } } @@ -418,39 +393,26 @@ constexpr auto conv_layout() /// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32). 
template constexpr builder::DataType conv_data_type() + requires HasDataTypes> { using InstTraits = InstanceTraits; using ADataType = typename InstTraits::ADataType; + using enum builder::DataType; if constexpr(std::is_same_v) - { - return builder::DataType::FP16; - } + return FP16; else if constexpr(std::is_same_v) - { - return builder::DataType::BF16; - } + return BF16; else if constexpr(std::is_same_v) - { - return builder::DataType::FP32; - } + return FP32; else if constexpr(std::is_same_v) - { - return builder::DataType::FP8; - } + return FP8; else if constexpr(std::is_same_v) - { - return builder::DataType::I8; - } + return I8; else if constexpr(std::is_same_v) - { - return builder::DataType::U8; - } + return U8; else - { - // Default fallback - return builder::DataType::FP32; - } + return FP32; // Default fallback } /// @brief Derives the elementwise operation from op type. @@ -459,27 +421,19 @@ constexpr builder::DataType conv_data_type() template constexpr builder::ElementwiseOperation elementwise_op() { + using enum builder::ElementwiseOperation; constexpr std::string_view name = detail::elementwise_op_name(); + if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) - { - return builder::ElementwiseOperation::BIAS_BNORM_CLAMP; - } - else if constexpr(detail::case_insensitive_equal(name, "Clamp")) - { - return builder::ElementwiseOperation::CLAMP; - } - else if constexpr(detail::case_insensitive_equal(name, "Scale")) - { - return builder::ElementwiseOperation::SCALE; - } - else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) - { - return builder::ElementwiseOperation::PASS_THROUGH; - } - else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) - { - return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU; - } + return BIAS_BNORM_CLAMP; + if constexpr(detail::case_insensitive_equal(name, "Clamp")) + return CLAMP; + if constexpr(detail::case_insensitive_equal(name, "Scale")) + return SCALE; + if 
constexpr(detail::case_insensitive_equal(name, "PassThrough")) + return PASS_THROUGH; + if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) + return SCALEADD_SCALEADD_RELU; } /// @brief Derives a gemm padding from a kernel instance type. @@ -487,6 +441,7 @@ constexpr builder::ElementwiseOperation elementwise_op() /// @return A `builder::GemmPadding` enum value corresponding to kernel padding. template constexpr builder::GemmPadding gemm_spec() + requires HasGemmSpec> { using InstTraits = InstanceTraits; using enum builder::GemmPadding; @@ -494,69 +449,24 @@ constexpr builder::GemmPadding gemm_spec() constexpr auto gemm_spec = InstTraits::kGemmSpecialization; - if constexpr(gemm_spec == Default) + switch(gemm_spec) { - return DEFAULT; - } - else if constexpr(gemm_spec == MPadding) - { - return M_PADDING; - } - else if constexpr(gemm_spec == NPadding) - { - return N_PADDING; - } - else if constexpr(gemm_spec == KPadding) - { - return K_PADDING; - } - else if constexpr(gemm_spec == MNPadding) - { - return MN_PADDING; - } - else if constexpr(gemm_spec == MKPadding) - { - return MK_PADDING; - } - else if constexpr(gemm_spec == NKPadding) - { - return NK_PADDING; - } - else if constexpr(gemm_spec == MNKPadding) - { - return MNK_PADDING; - } - else if constexpr(gemm_spec == OPadding) - { - return O_PADDING; - } - else if constexpr(gemm_spec == MOPadding) - { - return MO_PADDING; - } - else if constexpr(gemm_spec == NOPadding) - { - return NO_PADDING; - } - else if constexpr(gemm_spec == KOPadding) - { - return KO_PADDING; - } - else if constexpr(gemm_spec == MNOPadding) - { - return MNO_PADDING; - } - else if constexpr(gemm_spec == MKOPadding) - { - return MKO_PADDING; - } - else if constexpr(gemm_spec == NKOPadding) - { - return NKO_PADDING; - } - else if constexpr(gemm_spec == MNKOPadding) - { - return MNKO_PADDING; + case Default: return DEFAULT; + case MPadding: return M_PADDING; + case NPadding: return N_PADDING; + case KPadding: return 
K_PADDING; + case MNPadding: return MN_PADDING; + case MKPadding: return MK_PADDING; + case NKPadding: return NK_PADDING; + case MNKPadding: return MNK_PADDING; + case OPadding: return O_PADDING; + case MOPadding: return MO_PADDING; + case NOPadding: return NO_PADDING; + case KOPadding: return KO_PADDING; + case MNOPadding: return MNO_PADDING; + case MKOPadding: return MKO_PADDING; + case NKOPadding: return NKO_PADDING; + case MNKOPadding: return MNKO_PADDING; } } @@ -571,6 +481,7 @@ struct ConvTraits; /// set of traits directly from a fully-formed device kernel `Instance` type. /// It uses `InstanceTraits` to access the kernel's template parameters. template + requires IsXdlFwdConv> struct ConvTraits { using InstTraits = InstanceTraits; From f7650ee82b306a05d9c3c44d3feefdd570a4bd58 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 5 Dec 2025 09:30:22 +0100 Subject: [PATCH 11/65] fix enforcing fixedvectorsizes for ck tile conv (#3344) --- .../gemm_universal_pipeline_ag_bg_cr_policy.hpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index d843916f5e..76341af70b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -545,7 +545,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA() { using AsLayout = remove_cvref_t; using AsDataType = remove_cvref_t; @@ -555,6 +555,11 @@ struct UniversalGemmBasePolicy using ALayout = remove_cvref_t{}, AsLayout>>; using ADataType = remove_cvref_t{}, AsDataType>>; + if constexpr(Problem::FixedVectorSize) + { + return Problem::VectorSizeA; + } + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize - 
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB() { using BsLayout = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -584,6 +589,11 @@ struct UniversalGemmBasePolicy using BLayout = remove_cvref_t{}, BsLayout>>; using BDataType = remove_cvref_t{}, BsDataType>>; + if constexpr(Problem::FixedVectorSize) + { + return Problem::VectorSizeB; + } + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize Date: Fri, 5 Dec 2025 16:14:52 +0100 Subject: [PATCH 12/65] Add new section to changelog (#3295) * Add new section to changelog * Update CHANGELOG.md Co-authored-by: spolifroni-amd --------- Co-authored-by: spolifroni-amd --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b07e322fe1..a50303113d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). +## (Unreleased) Composable Kernel 1.3.0 + +### Added +* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. + +### Changed + +### Upcoming changes + ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added From f5b0af22722b130f03cac590ca9b8729b1b84991 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Fri, 5 Dec 2025 07:44:10 -0800 Subject: [PATCH 13/65] Simplify includes for CK builder reflection (#3357) We only want to import enums and types into the builder reflection code. But, some of the enums are included in much larger files or even big trees of include files. This leads to unintended mixing of code and very confusing interactions and symbol conflicts. We organize the includes and extract two new enum-only headers to help with decoupling in CK. This refactoring is critical if we want to include reflection in a device-operator "describe" method. 
* Remove a few unnecessary includes from headers in builder/reflect/. * Extract enums scheduler and pipeline to their own headers so they can be used without importing other code. * Order includes alphabetically for better organization. The immediate goal is to unblock reflection integration, and this type of cleanup helps the flexibility and robustness of the CK header library. --- .../ck_tile/builder/reflect/conv_traits.hpp | 26 +++--- .../builder/reflect/instance_traits_util.hpp | 42 +++++----- .../test/test_bwd_data_instance_traits.cpp | 7 +- .../test/test_bwd_weight_instance_traits.cpp | 10 ++- .../builder/test/test_fwd_instance_traits.cpp | 22 ++--- .../test/test_instance_traits_util.cpp | 18 ++-- .../grid/gridwise_gemm_pipeline_selector.hpp | 27 +----- include/ck/utility/blkgemmpipe_scheduler.hpp | 44 +--------- include/ck/utility/loop_scheduler.hpp | 28 +------ include/ck/utility/pipeline_enum.hpp | 40 +++++++++ include/ck/utility/scheduler_enum.hpp | 83 +++++++++++++++++++ 11 files changed, 197 insertions(+), 150 deletions(-) create mode 100644 include/ck/utility/pipeline_enum.hpp create mode 100644 include/ck/utility/scheduler_enum.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 918fd6bdb6..e5a5638887 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -4,20 +4,20 @@ #pragma once #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include 
"ck/utility/pipeline_enum.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck_tile/builder/conv_builder.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_util.hpp" +#include "ck_tile/builder/types.hpp" #include "ck_tile/ops/epilogue.hpp" -#include +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" namespace ck_tile::reflect::conv { diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 64996f96f7..1055cbc038 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -8,28 +8,30 @@ #pragma once #include -#include -#include -#include -#include -#include -#include #include -#include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ck_tile/ops/epilogue.hpp" +#include +#include +#include +#include +#include +#include +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm.hpp" +#include 
"ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" diff --git a/experimental/builder/test/test_bwd_data_instance_traits.cpp b/experimental/builder/test/test_bwd_data_instance_traits.cpp index 80e8ae8d98..f26b5d7caf 100644 --- a/experimental/builder/test/test_bwd_data_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_data_instance_traits.cpp @@ -2,9 +2,10 @@ // SPDX-License-Identifier: MIT #include -#include -#include -#include +#include "ck/ck.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" namespace { diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index 9b3cd169bb..c7c4e370e2 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -2,10 +2,12 @@ // SPDX-License-Identifier: MIT #include -#include -#include -#include -#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" namespace { diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 6a8f1f14e3..396533cef4 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -1,17 +1,19 @@ // Copyright (c) 
Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" namespace { diff --git a/experimental/builder/test/test_instance_traits_util.cpp b/experimental/builder/test/test_instance_traits_util.cpp index 42810ace72..852174b805 100644 --- a/experimental/builder/test/test_instance_traits_util.cpp +++ b/experimental/builder/test/test_instance_traits_util.cpp @@ -1,16 +1,16 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include #include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/sequence.hpp" +#include "ck_tile/builder/reflect/instance_traits_util.hpp" namespace ck_tile::reflect::detail { namespace { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 8d45b8fd74..751608299c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -5,24 +5,16 @@ #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include -#include #endif +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/loop_scheduler.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp" namespace ck { -enum struct PipelineVersion -{ - v1, - v2, - // v3 is only used in the Stream-K implementation. 
- v4, - weight_only, -}; - template Prefetch stages, number of loop is multiple of unroll stages - Empty, - // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add - // prefetchstages - Full, -}; - enum SchedulerGroup : uint32_t { SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp index f186d0fea9..b3303e1138 100644 --- a/include/ck/utility/loop_scheduler.hpp +++ b/include/ck/utility/loop_scheduler.hpp @@ -3,40 +3,20 @@ #pragma once -#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -#include -#endif - #include "ck/utility/common_header.hpp" +#include "ck/utility/scheduler_enum.hpp" namespace ck { -enum struct LoopScheduler -{ - Default, - Interwave, -}; - +/// @brief Helper function to get default loop scheduler +/// @details Returns the default loop scheduler based on compile-time configuration. constexpr LoopScheduler make_default_loop_scheduler() { #if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING return LoopScheduler::Interwave; #else return LoopScheduler::Default; -#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING +#endif } } // namespace ck - -#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) -{ - switch(s) - { - case ck::LoopScheduler::Default: os << "Default"; break; - case ck::LoopScheduler::Interwave: os << "Interwave"; break; - default: os << ""; - } - return os; -} -#endif diff --git a/include/ck/utility/pipeline_enum.hpp b/include/ck/utility/pipeline_enum.hpp new file mode 100644 index 0000000000..4421386f59 --- /dev/null +++ b/include/ck/utility/pipeline_enum.hpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#endif + +namespace ck { + +/// @brief Pipeline version enumeration for GEMM kernels +/// @details Defines different pipeline strategies for data movement and computation overlap +/// in GEMM kernels. This is a lightweight header containing only the enum definition, +/// extracted from gridwise_gemm_pipeline_selector.hpp to minimize dependencies. +enum struct PipelineVersion +{ + v1, ///< Version 1 pipeline + v2, ///< Version 2 pipeline + // v3 is only used in the Stream-K implementation. + v4, ///< Version 4 pipeline + weight_only, ///< Weight-only specialized pipeline +}; + +} // namespace ck + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) +{ + switch(p) + { + case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break; + case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break; + case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break; + case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break; + default: os << ""; + } + return os; +} +#endif diff --git a/include/ck/utility/scheduler_enum.hpp b/include/ck/utility/scheduler_enum.hpp new file mode 100644 index 0000000000..0c4bfabaf3 --- /dev/null +++ b/include/ck/utility/scheduler_enum.hpp @@ -0,0 +1,83 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#endif + +namespace ck { + +/// @brief Block GEMM pipeline version enumeration +/// @details Defines different block GEMM pipeline strategies. +/// This is a lightweight header containing only enum definitions, +/// extracted from blkgemmpipe_scheduler.hpp to minimize dependencies. 
+enum struct BlockGemmPipelineVersion +{ + // For GEMM + v1, ///< Naive pipeline + v2, ///< Memory-optimized pipeline + v3, ///< Compute-optimized pipeline + v4, ///< Compute-optimized with double LDS buffer + v5, ///< Compute-optimized with double global prefetch register buffer + + // For GEMM with preshuffled weight + // v1, single lds buffer + // v2, double lds buffer +}; + +/// @brief Block GEMM pipeline scheduler enumeration +/// @details Defines scheduling strategies for block GEMM pipelines. +enum struct BlockGemmPipelineScheduler +{ + Intrawave, ///< Schedule within a single wavefront + Interwave, ///< Schedule across multiple wavefronts +}; + +/// @brief Loop scheduler enumeration +/// @details Defines scheduling strategies for computational loops. +enum struct LoopScheduler +{ + Default, ///< Default scheduling strategy + Interwave, ///< Cross-wavefront scheduling +}; + +/// @brief Tail number enumeration for pipeline buffering +/// @details Defines the number of tail iterations in pipelined loops. 
+enum struct TailNumber +{ + // Single / Double buffer pipeline + Odd, ///< Odd number of iterations + Even, ///< Even number of iterations + + // Long prefetch pipeline, up to 8 + One, ///< One tail iteration + Two, ///< Two tail iterations + Three, ///< Three tail iterations + Four, ///< Four tail iterations + Five, ///< Five tail iterations + Six, ///< Six tail iterations + Seven, ///< Seven tail iterations + + // Unroll stages > Prefetch stages, number of loop is multiple of unroll stages + Empty, ///< No tail iterations + // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add + // prefetchstages + Full, ///< Full tail iterations +}; + +} // namespace ck + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) +{ + switch(s) + { + case ck::LoopScheduler::Default: os << "Default"; break; + case ck::LoopScheduler::Interwave: os << "Interwave"; break; + default: os << ""; + } + return os; +} +#endif From 82f796a1f096219da34d614b6084a03bb23f8dc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 5 Dec 2025 17:20:46 +0100 Subject: [PATCH 14/65] Profile resnet layout fixes (#3360) --- .../include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp | 4 ++-- profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp index 3cda620831..47a12e2d88 100644 --- a/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp @@ -75,13 +75,13 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor({N_, C_, H, W}, 
{C_ * H * W, H * W, W, 1_uz}, layout); } else if constexpr(is_same::value || is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout); } }; diff --git a/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp index 2a7ee6fd66..ac7ab78ed7 100644 --- a/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp @@ -75,13 +75,13 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout); } else if constexpr(is_same::value || is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout); } }; From 7541d9b5b0e0ce241eb75476e3ef5d61ba019210 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Fri, 5 Dec 2025 08:26:00 -0800 Subject: [PATCH 15/65] Ignore .cmake-format.yaml (#3356) We don't want to add cmake formatting until we are in the super repo, but its handy if developers want to experiment with formatting. For now we should ignore .cmake-format.yaml. 
--- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 2641a661d8..d8468cf24e 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,9 @@ tags # Editors .vscode +# CMake formatting configuration (local) +.cmake-format.yaml + # Cline .cline* From ed080f5a56c38caea8fedbd0bcc2919ba2376a6f Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Fri, 5 Dec 2025 09:35:27 -0700 Subject: [PATCH 16/65] Congma/ck tile/aquant mem pipeline (#3346) * [CK TILE GEMM QUANT] Fix the bug in HotLoopTail of memory pipeline --- .../run_gemm_quant_example.inc | 11 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 8 +- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 151 ++++++++++++++---- 3 files changed, 127 insertions(+), 43 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 396a54c7c2..0ee19b4a26 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -69,7 +69,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseGemmPipelineAgBgCrCompV3>; + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, + ck_tile::BaseGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV3>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -128,7 +133,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::GemmPipelineAgBgCrCompV3, std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + 
ck_tile::AQuantGemmPipelineAgBgCrMem>, std::conditional_t, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 71e0ebb957..38a22e38ac 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -36,17 +36,13 @@ struct BaseGemmPipelineAgBgCrMem // TODO: Is this 32K value gfx9 arch specific? static constexpr index_t MinMemInFlyBytes = 32768; - static constexpr index_t WgpPerCU = - (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1; + static constexpr index_t WgpPerCU = ck_tile::max(4 * get_warp_size() / BlockSize, 1); static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(MinMemInFlyBytes / WgpPerCU, (MPerBlock * sizeof(ADataType) / APackedSize + NPerBlock * sizeof(BDataType) / BPackedSize) * KPerBlock); - static constexpr index_t PrefetchStages = - FullMemBandPrefetchStages >= 2 - ? FullMemBandPrefetchStages <= 8 ? 
FullMemBandPrefetchStages : 8 - : 2; + static constexpr index_t PrefetchStages = ck_tile::clamp(FullMemBandPrefetchStages, 2, 8); static constexpr index_t LocalPrefillStages = 1; static constexpr index_t GlobalBufferNum = PrefetchStages; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index f3c8b7a1a3..7f89d98349 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -80,6 +80,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -165,6 +168,19 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { using Base = PipelineImplBase; + template + CK_TILE_DEVICE static void + LoadAndConvertATile(ABlockTile_& a_block_tile, + ADramWindow& a_dram_window, + const DramTileWindowStep& dram_tile_window_step) + { + using DestDataType = typename ABlockTile_::DataType; + using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(a_block_tile, a_dram_window); + move_tile_window(a_dram_window, dram_tile_window_step); + } + template const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, const AQDramBlockWindowTmp& aq_dram_block_window_tmp, - index_t m, + [[maybe_unused]] index_t m, index_t num_loop, void* p_smem) const { - (void)m; // unused variable static_assert( std::is_same_v> && std::is_same_v std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - 
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!"); - static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}], - "Aq block window has incorrect lengths for defined AqLayout!"); static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -217,7 +228,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem "B block window has incorrect lengths for defined BLayout!"); // A/B tiles in LDS - using the same approach as regular gemm pipeline - auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem); + auto ab_lds_blocks = Base::template GetABLdsTensorViews(p_smem); auto& a_lds_block = ab_lds_blocks.at(I0{}); auto& b_lds_block = ab_lds_blocks.at(I1{}); @@ -249,7 +260,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = @@ -272,7 +283,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem is_aq_col_major ? 
make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); // Global prefetch initialization - DRAM to VGPRs - Base::GlobalPrefetch( + LoadAndConvertATile( a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); Base::GlobalPrefetch( b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); @@ -282,10 +293,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS prefill - VGPRs to LDS - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -293,10 +304,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } @@ -306,9 +317,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } // Additional prefetching for memory pipeline - DRAM to VGPRs static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); + LoadAndConvertATile(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, 
b_dram_tile_window_step); @@ -325,16 +336,17 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, aq_block_tiles.get(number{}), a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); // Prepare next iteration data - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d( a_shuffle_tmp, @@ -348,7 +360,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -365,9 +377,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem b_element_func); } - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); + LoadAndConvertATile(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); @@ -381,20 +393,89 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } // Tail handling - block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - block_gemm( - c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); + auto HotLoopTail = [&](auto tail_num) { + static_for<0, tail_num - 1, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + 
block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + // no second block_sync_lds because it's interwave - if constexpr(TailNum == TailNumber::Even) - { + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, + a_block_tiles.get(number{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + Base::LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{})); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, + b_block_tiles.get(number{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + Base::LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{})); + } + }); - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I1{}), a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I1{}), b_element_func); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm( - c_block_tile, aq_block_tiles.get(I1{}), a_lds_gemm_window, b_lds_gemm_window); + c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); + } + else if constexpr(TailNum == TailNumber::Two) + { + HotLoopTail(number<2>{}); + } + else if constexpr(TailNum == 
TailNumber::Three) + { + HotLoopTail(number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + HotLoopTail(number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + HotLoopTail(number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + HotLoopTail(number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + HotLoopTail(number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + HotLoopTail(number{}); } return c_block_tile; } @@ -413,7 +494,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return PipelineImpl{} .template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const BDataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, From 608232ce82636e7c9ab8dec55dc7507c6792fb65 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 5 Dec 2025 08:39:18 -0800 Subject: [PATCH 17/65] do not build hipblaslt for gfx90a to save time and disc space (#3362) --- Dockerfile.pytorch | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.pytorch b/Dockerfile.pytorch index 9628bf46fa..2d3856fa2d 100644 --- a/Dockerfile.pytorch +++ b/Dockerfile.pytorch @@ -29,4 +29,4 @@ RUN groupadd -g 109 render && \ git sparse-checkout set projects/hipblaslt shared/origami && \ cd projects/hipblaslt && \ git show --oneline -s && \ - CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx90a;gfx942;gfx950" -j 128 --skip_rocroller + CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller From 6b1bceca7baea62941793e562d6ff58c571d9191 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Fri, 5 Dec 2025 09:57:52 -0800 Subject: [PATCH 18/65] [CK_Tile] Enable PreshuffleB for 2d block scale Gemm (#3298) * formatted * formatted * formatting * formatting * formatting * [CK 
TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * debugging permuteN * debugging * debugging PermuteN * initial commit * resolving merge conflicts * adding test cases * fixing bq tensor calculation --------- Co-authored-by: Cong Ma Co-authored-by: Thomas Ning --- .../gemm_bquant_quantgrouped_preshuffleb.cpp | 192 ++++++++++++++++-- .../run_gemm_quant_example.inc | 27 ++- include/ck_tile/host/tensor_shuffle_utils.hpp | 10 +- ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 13 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 1 - .../test_gemm_quant_bquant_preshuffle.cpp | 44 +++- .../test_gemm_quant_fixtures.hpp | 6 +- 7 files changed, 257 insertions(+), 36 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp index 8ebf5bbd96..b32356c29d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp @@ -14,36 +14,154 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut) { - using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& 
arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using 
TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; lut[hash_multiple_strings( {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + 
using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, @@ -52,10 +170,50 @@ void bquant_quantgrouped_preshuffleb_instance_factory( lut[hash_multiple_strings( {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = 
ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 0ee19b4a26..8a0dd9bc08 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -140,6 +140,13 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::WPQuantBPipelineAgBgCrV2, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; + constexpr bool TiledPermuteN = + (QuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; + if(s.log_level_ > 0) + { + printf( + "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, QuantGroupSize::kN); + } using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -382,7 +389,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, "K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode"); } } - ck_tile::index_t AQK, BQK; + ck_tile::index_t AQK, BQK, BQN = 0; if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize @@ -392,6 +399,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { AQK = 0; // No A quantization BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant) @@ -431,7 +439,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { stride_AQ = 0; // No A quantization - stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout)); + stride_BQ 
= ck_tile::get_default_stride(BQK, BQN, stride_BQ, is_row_major(bq_layout)); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) { @@ -471,7 +479,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); } else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) { @@ -557,7 +565,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, b_k_n.SetZero(); bq_tensor_ptr->SetZero(); } - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); @@ -610,7 +617,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor b_k_n_dev = b_k_n; if constexpr(GemmConfig::PreshuffleB) { - if constexpr(GemmConfig::TiledMMAPermuteN) + if constexpr(GemmConfig::TiledMMAPermuteN && QuantGroupSize::kN == 1) { printf("PreshuffleB with TiledMMAPermuteN\n"); b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); @@ -635,11 +642,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant) { - if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN) + if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN && + QuantGroupSize::kN == 1) { - printf("Preshuffle BQ with TiledMMAPermuteN \n"); ck_tile::HostTensor bq_permuted_host = - ck_tile::bq_permuteN(*bq_tensor_ptr); + ck_tile::bq_permuteN(*bq_tensor_ptr, QuantGroupSize::kN); if constexpr(GemmConfig::PreshuffleQuant) { @@ -659,7 +666,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& 
arg_parser, bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data()); } else + { bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data()); + } } invoke_gemm* t, int block_aq_k) } int m_ = t->get_lengths()[0]; int aqk_ = t->get_lengths()[1]; + if(aqk_ % block_aq_k != 0) { throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k."); @@ -110,7 +111,7 @@ auto shuffle_b(const ck_tile::HostTensor& t) } template -auto bq_permuteN(const ck_tile::HostTensor& t) +auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) { assert(t.get_lengths().size() == 2); @@ -118,8 +119,11 @@ auto bq_permuteN(const ck_tile::HostTensor& t) int bqk_ = t.get_lengths()[0]; constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; - ck_tile::HostTensor t_view( - {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_}); + ck_tile::HostTensor t_view({n_ / (GemmConfig::N_Tile / group_n), + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile / group_n, + NRepeat, + bqk_}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4}); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index b54a93614a..58b713cb35 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -28,7 +28,6 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg using QuantGroupSize = remove_cvref_t; static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); - static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!"); static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); @@ -205,7 +204,17 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg } else { - constexpr index_t reg_offset = nIter * 
KPerBlockBQ + kQScale; + index_t reg_offset = [&]() { + if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; float scale_reg_f = cvt_scale_to_fp32(scale_reg); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index f6cf4ce9be..dd85705cf2 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -747,7 +747,6 @@ struct QuantGemmKernel (splitk_batch_offset.splitted_k / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; - return make_naive_tensor_view( b_ptr, make_tuple(kFlatN, kFlatK), diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp index 59b267842f..6cde4bded5 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -19,6 +19,12 @@ using PkInt4 = ck_tile::pk_int4_t; using BQuantGrouped = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + // Type combinations for BQuant tests with PreshuffleB // Tuple format: @@ -37,7 +43,43 @@ using BPreshuffleBQuantTypes = ::testing::Types< std::tuple, std::tuple, std::tuple, - std::tuple + std::tuple, + + // //2d cases with preshuffle B + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + 
std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 3b62d8073e..7b16529aa8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -433,7 +433,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase b_k_n_dev = b_k_n; if constexpr(PreshuffleB) { - if constexpr(TiledMMAPermuteN) + if constexpr(TiledMMAPermuteN && QuantGroupSize::kN == 1) { printf("PreshuffleB with TiledMMAPermuteN\n"); b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); @@ -451,11 +451,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase bq_shuffle_host = - ck_tile::bq_permuteN(bq_bqk_bqn); + ck_tile::bq_permuteN(bq_bqk_bqn, QuantGroupSize::kN); bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data()); } else if constexpr(GemmConfig::PreshuffleQuant) From 86a84ae61122b8ed2d2e40e45f108a8fa23d3210 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Fri, 5 Dec 2025 14:18:30 -0800 Subject: [PATCH 19/65] Add the gfx1011 support on CK Tile with the SGPR builtin reading protection (#3350) * Finish the fixes * add the gfx1010 support macro * Fix the compilation error --- include/ck_tile/core/config.hpp | 7 ++++ .../core/tensor/tile_scatter_gather.hpp | 3 +- .../core/tensor/tile_window_linear.hpp | 3 +- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 37 +++++++++++++++---- 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index de97b46336..678a2fbfff 100644 --- a/include/ck_tile/core/config.hpp +++ 
b/include/ck_tile/core/config.hpp @@ -357,6 +357,12 @@ struct amdgcn_compiler_target_state #endif // __gfx950__ // GFX10 +#if defined(__gfx1010__) + static constexpr bool CK_TILE_ARCH_GFX1010 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1010 = false; +#endif + #if defined(__gfx1030__) static constexpr bool CK_TILE_ARCH_GFX1030 = true; #else @@ -493,6 +499,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \ diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 97a44f38e8..7a4da64c4a 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -533,7 +533,8 @@ struct tile_scatter_gather size_per_buf; const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(m0_init_value); // This should be wave independent + m0_set_with_memory( + amd_wave_read_first_lane(m0_init_value)); // This should be wave independent using Traits = load_store_traits; diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 815c1bf158..6c84122d01 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -517,7 +517,8 @@ struct tile_window_linear size_per_buf; const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(m0_init_value); // This should be wave independent + m0_set_with_memory( + amd_wave_read_first_lane(m0_init_value)); // This 
should be wave independent using vector_t = typename Base::Traits::vector_t; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index d83338fbb2..51f0f5f1b1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -99,28 +99,49 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV template CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { + // Estimated number of VMEM vector loads for A per block: + // total A bytes / (threads per block * vector width) constexpr index_t Aload_inst = (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize; + // Estimated number of VMEM vector loads for B per block: + // total B bytes / (threads per block * vector width) constexpr index_t Bload_inst = (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize; + + // Estimated number of VMEM loads for B's quant data (e.g. scales / zp). + // First ceil-divide by quant group size (how many elements share one scale), + // then by vector width to get an approximate number of vector loads. constexpr index_t BQload_inst = ck_tile::integer_divide_ceil( ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType), QuantGroupSize::kK * QuantGroupSize::kK), VectorLoadSize); - constexpr index_t kLdsVec = 8; + + // ToDo: Hardcoded, need to change in future. 
How many instruction emit per iteration + constexpr index_t kLdsInstCycle = 8; + // Total VMEM load instructions (A + B + quant data) constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; - constexpr index_t ds_read_inst = kMPerBlock / kLdsVec; - constexpr index_t ds_write_inst = Aload_inst; - constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); - constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + // Approximate number of LDS reads per block + constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle; + // Approximate number of LDS writes per block + // (e.g., writing A from VMEM into LDS once per A load) + constexpr index_t ds_write_inst = Aload_inst; + // Number of MFMA instructions per wave for one block tile: + constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + // How often (in MFMA units) we should insert DS (LDS) operations. + constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + // How often (in MFMA units) we should insert VMEM buffer loads. + // buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load + // is assumed to cover at most 4 MFMA instructions. constexpr index_t buffer_load_rep = min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma - static_for<0, nloop, 1>{}([&](auto j_inst) { - ignore = j_inst; + static_for<0, nloop, 1>{}([&](auto) { static_for<0, mfma_inst, 1>{}([&](auto i_inst) { __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA + // Insert LDS read/write groups periodically based on ds_rep. + // The % pattern staggers READ and WRITE so they don't collapse + // into the same cycle in the model. 
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0) { __builtin_amdgcn_sched_group_barrier( @@ -140,6 +161,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read } } + // Always mark some VALU work in the loop to reflect auxiliary scalar + // or vector ALU instructions that coexist with MFMA (Blockscale calculation). __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU }); }); From 8fec8054b2473bc7e367adf009f40a9d3fcc52df Mon Sep 17 00:00:00 2001 From: yinglu Date: Mon, 8 Dec 2025 16:24:20 +0800 Subject: [PATCH 20/65] ck: add tf32 in `DTYPES` to control instances build(#3317) --- CHANGELOG.md | 1 + CMakeLists.txt | 17 +++++++++++++++ README.md | 2 +- client_example/CMakeLists.txt | 12 +++++++++++ include/ck/config.h.in | 11 ++++++++++ .../gpu/grouped_convolution_backward_data.hpp | 16 ++++++++------ ...ped_convolution_backward_data_bilinear.hpp | 20 +++++++++++------- ...rouped_convolution_backward_data_scale.hpp | 20 +++++++++++------- .../grouped_convolution_backward_weight.hpp | 16 ++++++++------ ...d_convolution_backward_weight_bilinear.hpp | 11 +++++++--- ...uped_convolution_backward_weight_scale.hpp | 10 ++++++--- .../gpu/grouped_convolution_forward.hpp | 17 +++++++++------ ...d_convolution_forward_bias_bnorm_clamp.hpp | 16 ++++++++------ ...grouped_convolution_forward_bias_clamp.hpp | 18 +++++++++------- .../grouped_convolution_forward_bilinear.hpp | 10 ++++++--- .../gpu/grouped_convolution_forward_clamp.hpp | 17 ++++++++------- .../gpu/grouped_convolution_forward_scale.hpp | 10 ++++++--- .../gpu/CMakeLists.txt | 21 +++++++++++++------ .../src/profile_grouped_conv_bwd_data.cpp | 18 ---------------- .../src/profile_grouped_conv_bwd_weight.cpp | 16 -------------- profiler/src/profile_grouped_conv_fwd.cpp | 20 ------------------ .../profile_grouped_conv_fwd_bias_clamp.cpp | 6 ------ .../src/profile_grouped_conv_fwd_clamp.cpp | 6 ------ test/CMakeLists.txt | 6 
++++++ 24 files changed, 177 insertions(+), 140 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a50303113d..15fdb09f49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. +* Added TF32 convolution support on gfx942 and gfx950 in CK. It can be enabled or disabled by including or omitting "tf32" in `DTYPES`. ### Changed diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d0c4d79f9..acae1f5ece 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,10 @@ if (DTYPES) add_definitions(-DCK_ENABLE_FP32) set(CK_ENABLE_FP32 "ON") endif() + if (DTYPES MATCHES "tf32") + # definition will be added based on the GPU target in the following section + set(CK_ENABLE_TF32 "ON") + endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) set(CK_ENABLE_FP64 "ON") @@ -106,6 +110,7 @@ else() set(CK_ENABLE_INT8 "ON") set(CK_ENABLE_FP16 "ON") set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_TF32 "ON") set(CK_ENABLE_FP64 "ON") set(CK_ENABLE_BF16 "ON") set(CK_ENABLE_FP8 "ON") @@ -282,6 +287,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") set(CK_GFX950_SUPPORT "ON") endif() +if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32) + add_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "ON") +else() + message(STATUS "Disabling TF32 instances") + remove_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "OFF") +endif() + option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) @@ -651,6 +665,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES
"_f32") AND DTYPES MATCHES "fp32") set(add_inst 1) endif() + if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32") + set(add_inst 1) + endif() if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") set(add_inst 1) endif() diff --git a/README.md b/README.md index 01d523c2ab..8a5258bab6 100644 --- a/README.md +++ b/README.md @@ -187,7 +187,7 @@ limit the number of threads. For example, if you have a 128-core CPU and 128 Gb Additional cmake flags can be used to significantly speed-up the build: -* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build +* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build instances of select data types only. The main default data types are fp32 and fp16; you can safely skip other data types. diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 2ed338d08a..cab84f5c6c 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -27,6 +27,9 @@ if (DTYPES) add_definitions(-DCK_ENABLE_FP32) set(CK_ENABLE_FP32 "ON") endif() + if (DTYPES MATCHES "tf32") + set(CK_ENABLE_TF32 "ON") + endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) set(CK_ENABLE_FP64 "ON") @@ -41,6 +44,7 @@ else() set(CK_ENABLE_INT8 "ON") set(CK_ENABLE_FP16 "ON") set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_TF32 "ON") set(CK_ENABLE_FP64 "ON") set(CK_ENABLE_BF16 "ON") if (GPU_TARGETS MATCHES "gfx94") @@ -67,6 +71,14 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_FNUZ_FP8) set(CK_USE_FNUZ_FP8 "ON") endif() + if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32) + add_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "ON") + else() + message(STATUS "Disabling TF32 instances for this target") + remove_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "OFF") + endif() else() 
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) set(CK_USE_XDL "ON") diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 306a6c2ff1..113bf99243 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -55,6 +55,11 @@ #ifndef CK_ENABLE_FP32 #define CK_ENABLE_FP32 "ON" #endif +#ifndef CK_ENABLE_TF32 +#if defined(__gfx942__) || defined(__gfx95__) +#define CK_ENABLE_TF32 "ON" +#endif +#endif #ifndef CK_ENABLE_FP64 #define CK_ENABLE_FP64 "ON" #endif @@ -85,6 +90,12 @@ #cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@ #endif +#ifndef CK_ENABLE_TF32 +#if defined(__gfx942__) || defined(__gfx95__) +#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@ +#endif +#endif + #ifndef CK_ENABLE_FP64 #cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 03e3ae88a3..89009c6d0b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -115,12 +115,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); } #endif -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances( @@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( op_ptrs); @@ -139,8 +141,8 @@ struct DeviceOperationInstanceFactory< 
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -284,12 +286,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); } #endif -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( @@ -299,7 +301,9 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances( op_ptrs); } - else if constexpr(is_same_v) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( op_ptrs); @@ -308,8 +312,8 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp index cd65a2285a..84a715b70a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp @@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_in PassThrough, PassThrough, Bilinear>>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( std::vector && is_same_v && - is_same_v) + if constexpr(is_same_v && 
is_same_v && + is_same_v) { static_assert(is_same_v, "ComputeTypeA and ComputeTypeB must be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); } - else if constexpr(is_same_v) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp index 36980e5935..c898dbf781 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp @@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_insta PassThrough, PassThrough, Scale>>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( std::vector && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { static_assert(is_same_v, " only support same compute type"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); } - else if constexpr(is_same_v) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( op_ptrs); } - } #endif + } #ifdef 
CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index e677f6f848..3fe8fa9c5a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -347,12 +347,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: ComputeTypeA and ComputeTypeB should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( @@ -367,7 +367,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( op_ptrs); @@ -380,8 +382,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v && @@ -610,12 +612,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: ComputeTypeA and ComputeTypeB should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( @@ -629,7 +631,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); @@ 
-642,8 +646,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp index 448a6b5d51..a0e8e46570 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp @@ -62,6 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_ PassThrough, Bilinear, PassThrough>>>& instances); +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp index acf9c9e150..64bbdf6ec5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp @@ -62,7 +62,9 @@ void 
add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_ins PassThrough, Scale, PassThrough>>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index ba2f6b921a..5089ea2c1e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -198,12 +198,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same!"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); @@ -219,7 +219,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances( @@ -235,8 +237,8 @@ struct DeviceOperationInstanceFactory && is_same_v 
&& is_same_v && is_same_v && @@ -451,10 +453,10 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( @@ -472,7 +474,10 @@ struct DeviceOperationInstanceFactory && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( @@ -488,8 +493,8 @@ struct DeviceOperationInstanceFactory && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp index 46bc0d2320..d4729f4d13 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp @@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "A and B compute types should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { @@ -153,7 +153,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( op_ptrs); @@ -170,8 +172,8 @@ struct DeviceOperationInstanceFactory && @@ -229,12 +231,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "A and B compute types should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { @@ -253,7 +255,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { 
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); @@ -270,8 +274,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( @@ -152,7 +152,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); @@ -169,9 +171,8 @@ struct DeviceOperationInstanceFactory && @@ -221,12 +222,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( @@ -244,7 +245,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); @@ -261,9 +264,8 @@ struct DeviceOperationInstanceFactory>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && DLayouts::Size() == 1 && is_same_v, NDHWGK>) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp index 90852d2945..090c99819f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp @@ -127,12 +127,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( @@ -150,7 +150,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); @@ -167,9 +169,8 @@ struct DeviceOperationInstanceFactory && @@ -218,12 +219,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( @@ -241,7 +242,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); @@ -258,8 +261,8 @@ struct DeviceOperationInstanceFactory>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && DLayouts::Size() == 0) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && 
is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index eeaf269394..ef037526ca 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -13,6 +13,8 @@ function(add_instance_library INSTANCE_NAME) set(type1 "_f16") elseif(type MATCHES "fp32") set(type1 "_f32") + elseif(type MATCHES "tf32") + set(type1 "_tf32") elseif(type MATCHES "fp8") set(type1 "_f8") elseif(type MATCHES "bf16") @@ -27,8 +29,8 @@ function(add_instance_library INSTANCE_NAME) #if filename matches any selected type, exit type loop and do no exclude the file from the list set(test 0) break() - elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR - source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND + elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "tf32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR + source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_tf32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND NOT (source_name MATCHES type OR source_name MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) @@ -102,9 +104,11 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() # Only build tf32 instances for gfx942 & gfx950 - if(NOT (INST_TARGETS MATCHES "gfx942|gfx950") AND source_name MATCHES "_tf32_") - message(DEBUG "removing tf32 instance ${source} ") - 
list(REMOVE_ITEM ARGN "${source}") + if(source_name MATCHES "_tf32_") + if(NOT ((INST_TARGETS MATCHES "gfx942|gfx950") AND CK_ENABLE_TF32)) + message(DEBUG "removing tf32 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() endif() endforeach() @@ -223,6 +227,10 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "fp32 instance found!") set(add_inst 1) endif() + if(("${cmake_instance}" MATCHES "_tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32") + message(DEBUG "tf32 instance found!") + set(add_inst 1) + endif() if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") message(DEBUG "fp64 instance found!") set(add_inst 1) @@ -237,6 +245,7 @@ FOREACH(subdir_path ${dir_list}) "${cmake_instance}" MATCHES "_f16" OR "${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32" OR + "${cmake_instance}" MATCHES "_tf32" OR "${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64" OR "${cmake_instance}" MATCHES "_bf16" OR @@ -330,7 +339,7 @@ FOREACH(subdir_path ${dir_list}) list(APPEND CK_DEVICE_OTHER_INSTANCES $) endif() message(DEBUG "add_instance_directory ${subdir_path}") - endif() + endif() else() message(DEBUG "skip_instance_directory ${subdir_path}") endif() diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp index 62d6e860f9..cbf763fc13 100644 --- a/profiler/src/profile_grouped_conv_bwd_data.cpp +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -84,9 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using namespace ck::tensor_layout::convolution; @@ -143,9 +141,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, 
F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -164,9 +160,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -185,9 +179,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW) @@ -206,9 +198,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } } @@ -230,9 +220,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -251,9 +239,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -272,9 +258,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -293,9 +277,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) 
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } } diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index a18aab41a5..c4f154e180 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -99,9 +99,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using namespace ck::tensor_layout::convolution; @@ -162,9 +160,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -184,9 +180,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -210,9 +204,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -243,9 +235,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -270,9 +260,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else 
if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -306,9 +294,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -340,9 +326,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index c94b77dd4f..4319d849c8 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -105,9 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using INT8 = int8_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; -#if defined(__gfx942__) || defined(__gfx950__) using TF32 = ck::tf32_t; -#endif // using GNWC = ck::tensor_layout::convolution::GNWC; @@ -228,9 +226,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -253,9 +249,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout 
== ConvLayout::GNHWC_GKYXC_GNHWK) @@ -280,9 +274,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } // NHWGC_GKYXC_NHWGK @@ -306,9 +298,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -331,9 +321,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -352,9 +340,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) @@ -373,9 +359,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -416,9 +400,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } // 
NGCDHW_GKCZYX_NGKDHW @@ -439,9 +421,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp index 4eb12e6e19..79b9beb8c7 100644 --- a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp +++ b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp @@ -105,9 +105,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) using F32 = float; using BF16 = ck::bhalf_t; using F16 = ck::half_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; @@ -172,9 +170,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -194,9 +190,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp index 7df9fd6167..f497ee8da5 100644 --- a/profiler/src/profile_grouped_conv_fwd_clamp.cpp +++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp @@ -105,9 +105,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) using F32 = float; using BF16 = ck::bhalf_t; using F16 = ck::half_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGC = 
ck::tensor_layout::convolution::NDHWGC; @@ -175,9 +173,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -197,9 +193,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f8498c6c03..c221f11f46 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -65,6 +65,9 @@ function(add_test_executable TEST_NAME) if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES) set(test 1) endif() + if((source_name MATCHES "_tf32|_tf32") AND NOT "tf32" IN_LIST DTYPES) + set(test 1) + endif() if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES) set(test 1) endif() @@ -156,6 +159,9 @@ function(add_gtest_executable TEST_NAME) if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES) set(test 1) endif() + if((source_name MATCHES "_tf32|_tf32") AND NOT "tf32" IN_LIST DTYPES) + set(test 1) + endif() if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES) set(test 1) endif() From 04612c30ceab818cd6c03a3e833a6c6d1a21dafa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 8 Dec 2025 10:32:56 +0100 Subject: [PATCH 21/65] [CK_BUILDER] Ck Tile Grouped convolution factory (#3352) * [BUILDER] Ck Tile Grouped convolution factory * Part 2 * Fixes after rebase * Remove leftovers --- .../builder/conv_algorithm_concepts.hpp | 85 +++++++- .../ck_tile/builder/conv_algorithm_limits.hpp | 5 + .../builder/factory/conv_dispatcher.hpp | 29 ++- .../builder/factory/conv_fwd_dl_factory.hpp | 10 +- 
.../factory/conv_fwd_large_tensor_factory.hpp | 12 +- .../builder/factory/conv_fwd_v3_factory.hpp | 12 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 12 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 12 +- .../builder/factory/conv_tile_factory.hpp | 131 ++++++++++++ .../helpers/{ => ck}/conv_block_transfer.hpp | 0 .../helpers/{ => ck}/conv_elementwise_op.hpp | 0 .../helpers/{ => ck}/conv_tensor_layout.hpp | 0 .../helpers/{ => ck}/conv_tensor_type.hpp | 0 .../helpers/{ => ck}/conv_thread_block.hpp | 0 .../helpers/{ => ck}/conv_tuning_params.hpp | 0 .../ck_tile/conv_tile_block_transfer.hpp | 25 +++ .../ck_tile/conv_tile_elementwise_op.hpp | 62 ++++++ .../ck_tile/conv_tile_kernel_directions.hpp | 88 ++++++++ .../ck_tile/conv_tile_tensor_layout.hpp | 200 ++++++++++++++++++ .../helpers/ck_tile/conv_tile_tensor_type.hpp | 87 ++++++++ .../ck_tile/conv_tile_thread_block.hpp | 32 +++ .../ck_tile/conv_tile_tuning_params.hpp | 158 ++++++++++++++ .../builder/include/ck_tile/builder/types.hpp | 9 + experimental/builder/test/CMakeLists.txt | 31 +-- .../{ => ck}/test_ckb_conv_fwd_1d_bf16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_1d_fp16.cpp | 0 .../conv/{ => ck}/test_ckb_conv_fwd_1d_i8.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_bf16.cpp | 0 ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_dl_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp32.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp8.cpp | 0 ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_bf16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_fp32.cpp | 0 .../test/conv/{ => ck}/test_conv_traits.cpp | 0 .../test_ckb_conv_bwd_data_2d_fp16_v3.cpp | 52 +++++ .../test_ckb_conv_bwd_weight_2d_fp16_v3.cpp | 52 +++++ .../ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp | 52 +++++ .../test/impl/conv_algorithm_types.hpp | 118 +++++++++++ .../builder/test/unit_conv_elementwise_op.cpp | 2 
+- .../builder/test/unit_conv_tensor_layout.cpp | 2 +- .../builder/test/unit_conv_tensor_type.cpp | 2 +- .../builder/test/unit_conv_thread_block.cpp | 2 +- .../builder/test/unit_conv_tuning_params.cpp | 2 +- .../test/utils/ckb_conv_test_utils.hpp | 16 ++ .../test/utils/ckb_conv_tile_test_configs.hpp | 85 ++++++++ .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 4 +- .../gemm/pipeline/gemm_pipeline_problem.hpp | 7 +- ...ouped_convolution_backward_data_kernel.hpp | 17 +- ...ped_convolution_backward_weight_kernel.hpp | 37 ++-- .../grouped_convolution_forward_kernel.hpp | 36 ++-- .../utils/grouped_convolution_utils.hpp | 37 ++++ 55 files changed, 1431 insertions(+), 92 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_block_transfer.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_elementwise_op.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tensor_layout.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tensor_type.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_thread_block.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tuning_params.hpp (100%) create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp 
create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_i8.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_dl_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp32.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp8.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_fp32.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_conv_traits.cpp (100%) create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ecb1ff933e..bf7e89fcaa 100644 --- 
a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -95,6 +95,47 @@ concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; +// Concept for thread block dimensions for a GEMM problem for CK Tile (Block +// size is deduced from block gemm structure). +template +concept TileThreadBlockDescriptor = requires(T t) { + { t.tile_size.m } -> std::convertible_to; + { t.tile_size.n } -> std::convertible_to; + { t.tile_size.k } -> std::convertible_to; +}; + +// Concept for thread block dimensions for a GEMM problem for CK Tile (Block +// size is deduced from block gemm structure). +template +concept TileTransferDescriptor = requires(T t) { + { t.a_scalar_per_vector } -> std::convertible_to; + { t.b_scalar_per_vector } -> std::convertible_to; + { t.c_scalar_per_vector } -> std::convertible_to; +}; + +// Concept to check if struct specifies block GEMM (CK Tile). +template +concept TileBlockGemmDescriptor = requires(T t) { + { t.warps.m } -> std::convertible_to; + { t.warps.n } -> std::convertible_to; + { t.warps.k } -> std::convertible_to; + { t.warp_tile.m } -> std::convertible_to; + { t.warp_tile.n } -> std::convertible_to; + { t.warp_tile.k } -> std::convertible_to; + { t.double_smem_buffer } -> std::convertible_to; + { t.num_wave_groups } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; +}; + +// Concept to check if struct specifies optimizations (CK Tile). +template +concept TileOptimizationsDescriptor = requires(T t) { + { t.num_groups_to_merge } -> std::convertible_to; + { t.split_image } -> std::convertible_to; + { t.explicit_gemm } -> std::convertible_to; +}; + // Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this // concept. 
template @@ -110,6 +151,12 @@ concept SpecifiesThreadBlock = requires { { T::thread_block } -> ThreadBlockDescriptor; }; +// Concept to check if struct specifies thread block info (CK Tile). +template +concept SpecifiesTileThreadBlock = requires { + { T::thread_block } -> TileThreadBlockDescriptor; +}; + // Concept to check if a struct specifies gridwise XDL GEMM info. template concept SpecifiesGridwiseXdlGemm = requires { @@ -130,6 +177,14 @@ concept SpecifiesBlockTransfer = requires(T t) { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; +// Concept to check if a struct specifies convolution scalar per vector info for A, B and C. +template +concept SpecifiesTileTransfer = requires(T t) { + { T::transfer.a_scalar_per_vector } -> std::convertible_to; + { T::transfer.b_scalar_per_vector } -> std::convertible_to; + { T::transfer.c_scalar_per_vector } -> std::convertible_to; +}; + // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. template concept SpecifiesLdsTransfer = requires(T t) { @@ -159,8 +214,36 @@ concept SpecifiesBlockGemm = requires { { T::block_gemm.scheduler } -> std::convertible_to; }; +// Concept to check if struct specifies block GEMM (CK Tile). template -concept SpecifiesFwdConcSpecialization = requires { +concept SpecifiesTileBlockGemm = requires { + { T::block_gemm.warps.m } -> std::convertible_to; + { T::block_gemm.warps.n } -> std::convertible_to; + { T::block_gemm.warps.k } -> std::convertible_to; + { T::block_gemm.warp_tile.m } -> std::convertible_to; + { T::block_gemm.warp_tile.n } -> std::convertible_to; + { T::block_gemm.warp_tile.k } -> std::convertible_to; + { T::block_gemm.double_smem_buffer } -> std::convertible_to; + { T::block_gemm.num_wave_groups } -> std::convertible_to; + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; +}; + +// Concept to check if struct specifies optimizations (CK Tile). 
+template +concept SpecifiesTileOptimizations = requires { + { T::optimizations.num_groups_to_merge } -> std::convertible_to; + { T::optimizations.split_image } -> std::convertible_to; + { T::optimizations.explicit_gemm } -> std::convertible_to; +}; + +template +concept SpecifiesTileConvSpecialization = requires { + { T::specialization } -> std::convertible_to; +}; + +template +concept SpecifiesFwdConvSpecialization = requires { { T::fwd_specialization } -> std::convertible_to; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 093916dac3..10a619024a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires { Value.lds_dst_scalar_per_vector > 0; }; +// Limits for input and output vector transfer (CK Tile). +template +concept TileInputOutputVectorTransferLimits = + requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; }; + // Limits for output vector transfer. 
template concept OutputVectorTransferLimits = requires { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 51945544b2..9a9c2235e0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -59,6 +59,7 @@ #include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" +#include "ck_tile/builder/factory/conv_tile_factory.hpp" namespace ck_tile::builder::factory { @@ -81,6 +82,15 @@ namespace ck_tile::builder::factory { // // TODO: Make this dispatch logic much more robust and clear for users. +// CK Tile kernel +template +consteval bool IsTileAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && SpecifiesTileTransfer && + SpecifiesTileConvSpecialization && SpecifiesTileBlockGemm && + SpecifiesTileOptimizations; +} + // XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) template consteval bool IsXdlV3Algorithm() @@ -88,7 +98,7 @@ consteval bool IsXdlV3Algorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesBlockGemm; } @@ -99,7 +109,7 @@ consteval bool IsXdlAlgorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && 
SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; } @@ -111,7 +121,7 @@ consteval bool IsWmmaAlgorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; } @@ -120,7 +130,7 @@ template consteval bool IsDlAlgorithm() { return ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; } @@ -137,10 +147,15 @@ template constexpr auto make_conv_instance() { - if constexpr(ConvDirectionIsForward) - { - using AlgoType = std::remove_const_t; + using AlgoType = std::remove_const_t; + // CK Tile supports common factory for each direction + if constexpr(IsTileAlgorithm()) + { + return typename ConvTileFactory::Instance{}; + } + else if constexpr(ConvDirectionIsForward) + { if constexpr(IsXdlV3Algorithm()) { return typename ConvFwdXdlV3Factory::Instance{}; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 0c675ac7f1..ca202aabfd 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -7,11 +7,11 @@ #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include 
"ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 98e368ca61..fadf41f48a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include 
"ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 79955a1f44..89787cc1b3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index fcce46aea7..bb84479071 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -8,12 +8,12 @@ #include 
"ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index df7fb25168..8ec5c633ce 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include 
"ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp new file mode 100644 index 0000000000..cce95cb3f1 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -0,0 +1,131 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp" + +namespace ck_tile::builder::factory { + +// Factory for CK Tile Grouped Convolution kernels. 
+template +struct ConvTileFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::TileConvTensorLayouts; + using Types = internal::TileConvTensorTypes; + using Ops = internal::TileElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto CONV_SPECIALIZATION = internal::SetTileConvSpecialization(); + static constexpr auto BLOCK = internal::SetTileThreadBlockInfo(); + static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm(); + static constexpr auto OPTIMIZATIONS = internal::SetTileOptimizations(); + static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer(); + static constexpr auto CONV_DIRECTION = internal::SetTileConvDirection(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(TileInputOutputVectorTransferLimits); + + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + BLOCK_GEMM.double_smem_buffer, + typename GroupedConvTraitsType::template GemmLayouts::AsLayout, + typename GroupedConvTraitsType::template GemmLayouts::BsLayout, + typename GroupedConvTraitsType::template GemmLayouts::CLayout, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + BLOCK_GEMM.num_wave_groups>; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + 
typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + GemmShape, + GemmUniversalTraits, + BLOCK_GEMM.scheduler, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Types::EDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = typename internal::TilePipelineType< + BLOCK_GEMM.pipeline_version>::template GemmPipeline; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Instance = typename internal::GroupedConvolutionTileKernel::Instance; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp 
b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp new file mode 100644 index 0000000000..fbeb48b045 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp @@ -0,0 +1,25 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +struct TileScalarPerVector +{ + size_t a = 0; + size_t b = 0; + size_t c = 0; +}; + +template +constexpr TileScalarPerVector SetTileBlockTransfer() +{ + return TileScalarPerVector{.a = ALGORITHM.transfer.a_scalar_per_vector, + .b = ALGORITHM.transfer.b_scalar_per_vector, + .c = ALGORITHM.transfer.c_scalar_per_vector}; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp new file mode 100644 index 0000000000..45ff7d265d --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp @@ -0,0 +1,62 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct ElementwiseOpToCKTile +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported elementwise operation conversion to CK."); +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::PassThrough; +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::Scale; +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::Clamp; +}; + +template +consteval auto GetTileElementwiseOp() +{ + if constexpr(HasTensorOp) + { + constexpr auto op = TensorDesc.operation.elementwise_operation; + return ElementwiseOpToCKTile{}; + } + else + { + return ElementwiseOpToCKTile{}; + } +} + +template +struct TileElementwiseOps +{ + static 
constexpr auto input_op = GetTileElementwiseOp(); + static constexpr auto weight_op = GetTileElementwiseOp(); + static constexpr auto output_op = GetTileElementwiseOp(); + using AElementwiseOp = typename decltype(input_op)::Op; + using BElementwiseOp = typename decltype(weight_op)::Op; + using CDEElementwiseOp = typename decltype(output_op)::Op; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp new file mode 100644 index 0000000000..189b199ffc --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct GroupedConvolutionTileKernel +{ + static_assert(false, "Unknown Direction"); +}; + +template + requires ConvDirectionIsForward +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionForwardKernel; +}; + +template + requires ConvDirectionIsBackwardData +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionBackwardDataKernel; +}; + +template + requires ConvDirectionIsBackwardWeight +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionBackwardWeightKernel; +}; + +template +consteval ck_tile::GroupedConvDirection SetTileConvDirection() +{ + constexpr auto direction = SIGNATURE.direction; + using ck_tile_direction = ck_tile::GroupedConvDirection; + switch(direction) + { + case ConvDirection::FORWARD: return ck_tile_direction::FORWARD; + case ConvDirection::BACKWARD_DATA: return ck_tile_direction::BACKWARD_DATA; + case 
ConvDirection::BACKWARD_WEIGHT: return ck_tile_direction::BACKWARD_WEIGHT; + default: throw "Unknown Direction"; + } +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp new file mode 100644 index 0000000000..2aaca98586 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp @@ -0,0 +1,200 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { +using ALayout = ck_tile::tensor_layout::convolution::NWGC; +template +struct LayoutToCKTile +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported layout conversion to CK."); +}; + +// Bias layouts +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::G_K; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::G_C; +}; + +// Input 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNWC; +}; + +// Input 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NHWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNHWC; +}; + +// Input 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NDHWGC; +}; +template <> +struct LayoutToCKTile +{ + using 
type = ck_tile::tensor_layout::convolution::GNDHWC; +}; + +// Weight 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKXC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCX; +}; + +// Weight 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKYXC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCYX; +}; + +// Weight 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCZYX; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKZYXC; +}; + +// Output 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNWK; +}; + +// Output 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NHWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNHWK; +}; + +// Output 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NDHWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNDHWK; +}; + +template +consteval auto TensorLayoutToCKTile() +{ + return typename LayoutToCKTile::type{}; +} + +struct EmptyAuxiliaryTileTensorLayout +{ + using type = ck_tile::tuple<>; +}; + +template +consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence) +{ + return ck_tile::tuple< + decltype(TensorLayoutToCKTile())...>{}; +} + +template + requires(ConvSpatialDim) +struct AuxiliaryTileTensorLayouts +{ + static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size(); + using type = decltype(GetAuxiliaryTileTensorLayoutTuple( + std::make_index_sequence{})); +}; + +// TODO: Currently only the output tensor 
can have auxiliary tensors (e.g., bias). +template + requires(HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTileTensorLayouts() +{ + return AuxiliaryTileTensorLayouts{}; +} + +template + requires(!HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTileTensorLayouts() +{ + return EmptyAuxiliaryTileTensorLayout{}; +} + +template + requires(ConvSpatialDim && + ValidConvInputLayoutForSpatialDim && + ValidConvWeightLayoutForSpatialDim && + ValidConvOutputLayoutForSpatialDim) +struct TileConvTensorLayouts +{ + using ALayout = decltype(TensorLayoutToCKTile()); + using BLayout = decltype(TensorLayoutToCKTile()); + using ELayout = decltype(TensorLayoutToCKTile()); + using DsLayout = decltype(GetAuxiliaryTileTensorLayouts())::type; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp new file mode 100644 index 0000000000..493fbb7d9b --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/builder_utils.hpp" + +namespace ck_tile::builder::factory::internal { + +// Type mappings from builder convolution data type to CK Tile tensor types. +template +struct TileConvTensorTypes +{ + // This will trigger if a specialization for the given DataType is not found. + // We should always catch this in an earlier validation check. + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Internal error. 
Unsupported data type for convolution factory."); +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::half_t; + using AComputeType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using BComputeType = ck_tile::half_t; + using CShuffleDataType = ck_tile::half_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::half_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::bf16_t; + using AComputeType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using BComputeType = ck_tile::bf16_t; + using CShuffleDataType = ck_tile::bf16_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::bf16_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = float; + using AComputeType = float; + using BDataType = float; + using BComputeType = float; + using CShuffleDataType = float; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = float; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = int8_t; + using AComputeType = int8_t; + using BDataType = int8_t; + using BComputeType = int8_t; + using CShuffleDataType = int8_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = int32_t; + using EDataType = int8_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::fp8_t; + using AComputeType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using BComputeType = ck_tile::fp8_t; + using CShuffleDataType = ck_tile::fp8_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::fp8_t; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp new file mode 100644 index 
0000000000..65d81a49c4 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +// Convenience struct for a tuple of m, n, and k values. +struct TileBlockMNK +{ + int m{}; + int n{}; + int k{}; +}; + +struct TileConvBlock +{ + TileBlockMNK per_block = {}; +}; + +template +constexpr TileConvBlock SetTileThreadBlockInfo() +{ + constexpr auto& TB = ALGORITHM.thread_block; + return TileConvBlock{ + .per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}, + }; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp new file mode 100644 index 0000000000..b7df0e4d0e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp @@ -0,0 +1,158 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +// Convenience struct for a tuple of m, n, and k values. 
+struct TileBlockGemmMNK +{ + int m{}; + int n{}; + int k{}; +}; + +struct TileBlockGemmSpec +{ + TileBlockGemmMNK warps = {}; + TileBlockGemmMNK warp_tile = {}; + + bool double_smem_buffer = false; + int num_wave_groups = 1; + + ck_tile::GemmPipeline pipeline_version; + ck_tile::GemmPipelineScheduler scheduler; +}; + +struct TileOptimizations +{ + int num_groups_to_merge = 1; + bool split_image = false; + bool explicit_gemm = false; +}; + +template +consteval ck_tile::GemmPipelineScheduler SetTileScheduler() +{ + constexpr auto scheduler = ALGORITHM.block_gemm.scheduler; + using ck_tile_sched = ck_tile::GemmPipelineScheduler; + switch(scheduler) + { + case PipelineScheduler::DEFAULT: return ck_tile_sched::Default; + case PipelineScheduler::INTERWAVE: return ck_tile_sched::Interwave; + case PipelineScheduler::INTRAWAVE: return ck_tile_sched::Intrawave; + default: throw "Unknown PipelineScheduler"; + } +} + +template +struct TilePipelineType +{ + static_assert(false, "Unknown PipelineVersion"); +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; +}; + +template +consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() +{ + constexpr auto version = ALGORITHM.block_gemm.pipeline_version; + using ck_tile_pipeline = ck_tile::GemmPipeline; + switch(version) + { + case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1; + case PipelineVersion::V2: return ck_tile_pipeline::MEMORY; + case PipelineVersion::V3: return 
ck_tile_pipeline::COMPUTE_V3; + case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4; + case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5; + case PipelineVersion::WEIGHT_ONLY: + throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; + default: throw "Unknown block GEMM PipelineVersion"; + } +} + +template +consteval ck_tile::ConvolutionSpecialization SetTileConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.specialization; + using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization; + switch(specialization) + { + case TileConvSpecialization::DEFAULT: return ck_tile_conv_spec::Default; + case TileConvSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0; + case TileConvSpecialization::FILTER_1X1_STRIDE1_PAD0: + return ck_tile_conv_spec::Filter1x1Stride1Pad0; + case TileConvSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3; + default: throw "Unknown TileConvSpecialization"; + } +} + +template +consteval TileBlockGemmSpec SetTileBlockGemm() +{ + constexpr auto& BG = ALGORITHM.block_gemm; + + constexpr bool double_smem_buffer = BG.double_smem_buffer; + constexpr int num_wave_groups = BG.num_wave_groups; + + constexpr ck_tile::GemmPipeline pipeline_version = SetTileBlockGemmPipelineVersion(); + constexpr ck_tile::GemmPipelineScheduler scheduler = SetTileScheduler(); + + return TileBlockGemmSpec{ + .warps = {.m = BG.warps.m, .n = BG.warps.n, .k = BG.warps.k}, + .warp_tile = {.m = BG.warp_tile.m, .n = BG.warp_tile.n, .k = BG.warp_tile.k}, + .double_smem_buffer = double_smem_buffer, + .num_wave_groups = num_wave_groups, + .pipeline_version = pipeline_version, + .scheduler = scheduler}; +} + +template +consteval TileOptimizations SetTileOptimizations() +{ + constexpr auto& OPT = ALGORITHM.optimizations; + + return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge, + .split_image = OPT.split_image, + .explicit_gemm = OPT.explicit_gemm}; +} + +} // 
namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 565bb98528..532d8a1882 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -145,6 +145,15 @@ enum struct GemmSpecialization MNKOPadding }; +// Enums for the CK Tile convolution specialization. +enum class TileConvSpecialization +{ + DEFAULT, + FILTER_1X1_PAD0, + FILTER_1X1_STRIDE1_PAD0, + FILTER_3x3 +}; + // Enums for the forward convolution specialization. enum class ConvFwdSpecialization { diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index a340a789de..eef1110d27 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -90,7 +90,7 @@ add_ck_builder_test(test_ckb_conv_builder # Tests convolution trait selection and configuration add_ck_builder_test(test_ckb_conv_traits - conv/test_conv_traits.cpp) + conv/ck/test_conv_traits.cpp) # Tests convolution problem description and parameter handling add_ck_builder_test(test_ckb_conv_description @@ -119,19 +119,22 @@ add_ck_builder_test(test_ckb_instance_string # Tests the forward convolution builder across multiple data types and dimensions. # Individual tests are split into separate files to enable parallel compilation. 
add_ck_builder_test(test_ckb_build_fwd_instances - conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp - conv/test_ckb_conv_fwd_1d_fp16.cpp - conv/test_ckb_conv_fwd_1d_bf16.cpp - conv/test_ckb_conv_fwd_1d_i8.cpp - conv/test_ckb_conv_fwd_2d_fp8.cpp - conv/test_ckb_conv_fwd_2d_bf16.cpp - conv/test_ckb_conv_fwd_2d_fp16.cpp - conv/test_ckb_conv_fwd_2d_fp32.cpp - conv/test_ckb_conv_fwd_2d_dl_fp16.cpp - conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp - conv/test_ckb_conv_fwd_3d_bf16.cpp - conv/test_ckb_conv_fwd_3d_fp16.cpp - conv/test_ckb_conv_fwd_3d_fp32.cpp + conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp + conv/ck/test_ckb_conv_fwd_1d_fp16.cpp + conv/ck/test_ckb_conv_fwd_1d_bf16.cpp + conv/ck/test_ckb_conv_fwd_1d_i8.cpp + conv/ck/test_ckb_conv_fwd_2d_fp8.cpp + conv/ck/test_ckb_conv_fwd_2d_bf16.cpp + conv/ck/test_ckb_conv_fwd_2d_fp16.cpp + conv/ck/test_ckb_conv_fwd_2d_fp32.cpp + conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp + conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp + conv/ck/test_ckb_conv_fwd_3d_bf16.cpp + conv/ck/test_ckb_conv_fwd_3d_fp16.cpp + conv/ck/test_ckb_conv_fwd_3d_fp32.cpp + conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp + conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp + conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp ) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp 
b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp similarity index 
100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp similarity index 100% rename from experimental/builder/test/conv/test_conv_traits.cpp rename to experimental/builder/test/conv/ck/test_conv_traits.cpp diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp new file mode 100644 index 0000000000..ad31fc52bc --- /dev/null +++ 
b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::BACKWARD_DATA, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_backward_data", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp new file mode 100644 index 0000000000..47908e0e5b --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::BACKWARD_WEIGHT, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_backward_weight", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp new file mode 100644 index 0000000000..083d9d9955 --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_forward", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index d89d83357f..29c7f3cdcc 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -243,6 +243,73 @@ struct LargeTensorWrapper ConvAlgorithmSpecialization::LARGE_TENSOR; }; +// Specify thread block dimensions for a GEMM (CK Tile). +struct TileThreadBlock +{ + // Size of the submatrix problem in a thread block. 
+ MNK tile_size; +}; +static_assert(ckb::TileThreadBlockDescriptor); + +struct TileTransfer +{ + size_t a_scalar_per_vector; + size_t b_scalar_per_vector; + size_t c_scalar_per_vector; +}; +static_assert(ckb::TileTransferDescriptor); + +struct TileBlockGemm +{ + // Number of warps per each dimension. + MNK warps; + // Number of data processed per each dimension for each XDL/WMMA instruction. + MNK warp_tile; + // Double LDS buffer. + bool double_smem_buffer; + // Waves grouping (Ping-Pong scheduler). + int num_wave_groups; + PipelineVersion pipeline_version; + PipelineScheduler scheduler; +}; +static_assert(ckb::TileBlockGemmDescriptor); + +struct TileOptimizations +{ + // Number of convolution groups processed per one workgroup + int num_groups_to_merge; + // Split image for large tensors + bool split_image; + // Explicit gemm for 1x1, stride=1, pad=0 cases + bool explicit_gemm; +}; +static_assert(ckb::TileOptimizationsDescriptor); + +struct TileConvSpecialization_ +{ + TileConvSpecialization specialization; +}; + +struct TileThreadBlock_ +{ + TileThreadBlock thread_block; +}; + +struct TileTransfer_ +{ + TileTransfer transfer; +}; + +struct TileBlockGemm_ +{ + TileBlockGemm block_gemm; +}; + +struct TileOptimizations_ +{ + TileOptimizations optimizations; +}; + // Factory template @@ -339,6 +406,51 @@ struct ConvAlgorithmTemplate : Components... 
result.transfer = t; return result; } + + template + constexpr auto with_tile_specializations(const S& s) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.specialization = s; + return result; + } + + template + constexpr auto with_tile_thread_block(const TB& tb) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.thread_block = tb; + return result; + } + + template + constexpr auto with_tile_block_gemm(const BG& bg) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.block_gemm = bg; + return result; + } + + template + constexpr auto with_tile_transfer(const T& t) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.transfer = t; + return result; + } + + template + constexpr auto with_tile_optimizations(const O& o) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.optimizations = o; + return result; + } }; // Algorithm types @@ -361,4 +473,10 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = LargeTensorWrapper; +using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/unit_conv_elementwise_op.cpp b/experimental/builder/test/unit_conv_elementwise_op.cpp index 84a9c533f6..610edd281e 100644 --- a/experimental/builder/test/unit_conv_elementwise_op.cpp +++ b/experimental/builder/test/unit_conv_elementwise_op.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index 7764e94dc6..26df33cc8d 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ 
b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "impl/conv_signature_types.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp index c92b24626e..7ffd446966 100644 --- a/experimental/builder/test/unit_conv_tensor_type.cpp +++ b/experimental/builder/test/unit_conv_tensor_type.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_thread_block.cpp b/experimental/builder/test/unit_conv_thread_block.cpp index f829708696..ce5a772cfa 100644 --- a/experimental/builder/test/unit_conv_thread_block.cpp +++ b/experimental/builder/test/unit_conv_thread_block.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index 82117c53d8..b35a1ced55 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -3,7 +3,7 @@ #include -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" namespace { diff --git a/experimental/builder/test/utils/ckb_conv_test_utils.hpp b/experimental/builder/test/utils/ckb_conv_test_utils.hpp index 508c621c2e..1acf170455 100644 --- a/experimental/builder/test/utils/ckb_conv_test_utils.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_utils.hpp @@ -28,4 +28,20 @@ constexpr void run_test(const 
std::vector& kernel_instance_componen } } +// Common CK Tile test implementation +template +constexpr void run_ck_tile_test(const std::vector& kernel_instance_components) +{ + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + std::cout << kernel_string << std::endl; + for(const auto& component : kernel_instance_components) + { + EXPECT_THAT(kernel_string, ::testing::HasSubstr(component)); + } +} + } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp new file mode 100644 index 0000000000..377234dd19 --- /dev/null +++ b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "impl/conv_algorithm_types.hpp" +#include "impl/conv_signature_types.hpp" +#include "ck_tile/builder/conv_builder.hpp" + +namespace ck_tile::builder::test_utils { + +using namespace ck_tile::builder; +using namespace test; + +constexpr TileTransfer FwdTileTransfer_1x1x1{ + .a_scalar_per_vector = 1, + .b_scalar_per_vector = 1, + .c_scalar_per_vector = 1, +}; + +constexpr TileTransfer FwdTileTransfer_4x4x4{ + .a_scalar_per_vector = 4, + .b_scalar_per_vector = 4, + .c_scalar_per_vector = 4, +}; + +constexpr TileTransfer FwdTileTransfer_8x8x8{ + .a_scalar_per_vector = 8, + .b_scalar_per_vector = 8, + .c_scalar_per_vector = 8, +}; + +constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; + +constexpr TileThreadBlock 
FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; + +constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V1, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v2_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V2, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v3_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V3, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v4_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V4, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v5_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V5, + .scheduler = PipelineScheduler::INTRAWAVE}; + +} // namespace ck_tile::builder::test_utils diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index d4475e8c60..8fae704203 100644 --- 
a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -176,8 +176,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); return concat('_', "pipeline_AgBgCrCompV3", concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), concat('x', WaveNumM, WaveNumN), - concat('x', kPadM, kPadN, kPadK)); + concat('x', kPadM, kPadN, kPadK), + Problem::GetName()); // clang-format on } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 2c6b1f3d48..e35f4ce70d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -301,7 +301,12 @@ struct UniversalGemmPipelineProblem return concat('_', "gemm_problem", concat('x', kBlockSize), concat('x', kPadM, kPadN, kPadK), - Scheduler); + Scheduler, + "NumWaveGroups", + NumWaveGroups, + "DoubleSmemBuffer", + DoubleSmemBuffer + ); // clang-format on } }; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index e172e732fa..46c60cb6d7 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -560,16 +560,31 @@ struct GroupedConvolutionBackwardDataKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off return concat('_', 
"grouped_convolution_backward_data", gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, "gemm", GemmPipeline::GetName(), "epilogue", - EpiloguePipeline::GetName()); + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 6ef1d84a6e..f43bfdacac 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -417,26 +417,31 @@ struct GroupedConvolutionBackwardWeightKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { - constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - if (NumGroupsToMerge > 1) { - return concat('_', "grouped_convolution_backward_weight", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); - } else { - return concat('_', "grouped_convolution_backward_weight", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName(), "merge", NumGroupsToMerge); - } + return concat('_', "grouped_convolution_backward_weight", + gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, + "gemm", + 
GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 72ba17c5a5..a9f3274805 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -594,26 +594,28 @@ struct GroupedConvolutionForwardKernel { constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - if (NumGroupsToMerge > 1) { - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName(), - "merge", - NumGroupsToMerge); - } else { - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); - } + return concat('_', "grouped_convolution_forward", + gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const 
{ diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 71739c9083..5b00e53af8 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -9,6 +9,13 @@ namespace ck_tile { +enum class GroupedConvDirection +{ + FORWARD, + BACKWARD_DATA, + BACKWARD_WEIGHT +}; + /// @brief The Grouped Conv kernel host arguments. /// /// @par Overview @@ -113,6 +120,36 @@ struct GroupedConvTraits using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + template + struct GemmLayouts + { + static_assert(false, "Unsupported direction."); + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutFwd; + using BsLayout = BsLayoutFwd; + using CLayout = CLayoutFwd; + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutBwdData; + using BsLayout = BsLayoutBwdData; + using CLayout = CLayoutBwdData; + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutBwdWeight; + using BsLayout = BsLayoutBwdWeight; + using CLayout = CLayoutBwdWeight; + }; + template using GroupedConvImplicitGemmTraitsFwd = TileGemmTraits; From 878b4e7f46d7e47618f4d860d71b438cb6d992fd Mon Sep 17 00:00:00 2001 From: Yi DING Date: Mon, 8 Dec 2025 19:20:44 +0800 Subject: [PATCH 22/65] [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 (#3287) * [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 * typo --- .../ops/flatmm/kernel/mx_flatmm_kernel.hpp | 134 +++++++----------- ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 81 ++++++----- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 47 +++++- 3 files changed, 141 insertions(+), 121 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp 
b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index d9fb144176..1133da33ad 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -18,21 +18,21 @@ struct MXFlatmmKernel : FlatmmKernel; - using TilePartitioner = remove_cvref_t; - using FlatmmPipeline = remove_cvref_t; + using TilePartitioner = remove_cvref_t; + using MXFlatmmPipeline = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; using DsLayout = remove_cvref_t; using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; - static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel; + static constexpr index_t KernelBlockSize = MXFlatmmPipeline::BlockSize; + static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. 
using EDataType = remove_cvref_t; @@ -43,9 +43,9 @@ struct MXFlatmmKernel : FlatmmKernel::PackedSize; static constexpr int BPackedSize = numeric_traits::PackedSize; - static constexpr int MXdlPack = FlatmmPipeline::MXdlPack; - static constexpr int NXdlPack = FlatmmPipeline::NXdlPack; - static constexpr int KXdlPack = FlatmmPipeline::KXdlPack; + static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack; + static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack; + static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack; static constexpr index_t NumDTensor = DsDataType::size(); @@ -63,7 +63,7 @@ struct MXFlatmmKernel : FlatmmKernel, FlatmmPipeline::GetName()); + return concat('_', "mx_flatmm_gemm", gemm_prec_str, MXFlatmmPipeline::GetName()); // clang-format on } @@ -123,33 +123,23 @@ struct MXFlatmmKernel : FlatmmKernel) - { - return make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } + static_assert(std::is_same_v, + "A tensor for mx must be RowMajor"); + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); }(); - constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock; + constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock; constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1); constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile; const index_t kFlatKBlocks = kargs.K / kKPerBlock; const index_t kFlatN = kargs.N / kNWarpTile; const auto& b_flat_tensor_view = [&]() { - static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0, + static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0, "wrong! 
vector size for B tensor"); auto&& naive_desc = make_naive_tensor_descriptor_packed( make_tuple(kFlatN, kFlatKBlocks, number{})); @@ -262,20 +252,12 @@ struct MXFlatmmKernel : FlatmmKernel) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } + static_assert(std::is_same_v, + "A tensor for mx must be RowMajor"); + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); }(); const auto& b_flat_tensor_view = views.at(I1); @@ -289,14 +271,14 @@ struct MXFlatmmKernel : FlatmmKernel{}, number{}), - sequence{}); + sequence{}); } else { return pad_tensor_view(d_tensor_view[i], make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }, number{}); @@ -309,14 +291,14 @@ struct MXFlatmmKernel : FlatmmKernel{}, number{}), - sequence{}); + sequence{}); } else { return pad_tensor_view(e_tensor_view, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); @@ -334,26 +316,18 @@ struct MXFlatmmKernel : FlatmmKernel) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } + static_assert(std::is_same_v, + "A tensor for mx must be RowMajor"); + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); }(); const auto& b_flat_block_window = make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), + make_tuple(number{}, + number{}), {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); const auto ds_block_window = generate_tuple( @@ -444,14 +418,14 @@ struct MXFlatmmKernel : FlatmmKernel(kargs.a_ptr) + - splitk_batch_offset.a_k_split_offset / APackedSize; - const BDataType* b_flat_ptr = static_cast(kargs.b_ptr) + - splitk_batch_offset.b_k_split_offset / BPackedSize; + const auto a_ptr = 
static_cast(kargs.a_ptr) + + splitk_batch_offset.a_k_split_offset / APackedSize; + const auto b_flat_ptr = static_cast(kargs.b_ptr) + + splitk_batch_offset.b_k_split_offset / BPackedSize; EDataType* e_ptr = static_cast(kargs.e_ptr); // allocate LDS @@ -501,7 +475,7 @@ struct MXFlatmmKernel : FlatmmKernel::value)) { - constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); + constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1); RunFlatmm(a_ptr, b_flat_ptr, kargs.ds_ptr, diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index ff799cb0fc..87ae7f57d8 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -34,13 +34,11 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem; using CLayout = remove_cvref_t; + static constexpr index_t APackedSize = numeric_traits::PackedSize; + static constexpr index_t BPackedSize = numeric_traits::PackedSize; + using BlockFlatmm = remove_cvref_t())>; @@ -81,8 +82,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1::PackedSize; - static constexpr index_t BPackedSize = numeric_traits::PackedSize; + // static constexpr index_t WG_AKPacks = WG::kK / APackedSize; + // static constexpr index_t WG_BKPacks = WG::kK / BPackedSize; static constexpr index_t MXdlPack = Problem::MXdlPack; static constexpr index_t NXdlPack = Problem::NXdlPack; static constexpr index_t KXdlPack = Problem::KXdlPack; static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK; - static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize; - static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize; + static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType); + static constexpr 
index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType); static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) ? DsReadPreload @@ -562,11 +563,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeMX_BFlatDramTileDistribution()); + auto b_flat_dram_window = PipelinePolicy::template MakeMX_BFlatBytesDramWindow( + b_flat_dram_block_window_tmp); auto b_flat_dram_offsets = generate_tuple( [&](auto nIter) { constexpr auto packed_n_idx = nIter / number{}; @@ -621,7 +619,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, true_type{}, false_type{}); + async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{}); }; // HEAD @@ -633,11 +631,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); }); // move B window to next flat K b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); + tuple, number>{}); }); // prefetch Scale A @@ -698,12 +697,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); // move B window to next flat K if constexpr(kIter == KIterPerWarp - 1) b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); + tuple, number>{}); }); }); @@ -739,8 +738,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( 
c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_ping(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_ping(number{})(number{})), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) @@ -792,12 +793,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); // move B window to next flat K if constexpr(kIter == KIterPerWarp - 1) b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); + tuple, number>{}); }); }); @@ -833,8 +834,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_pong(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_pong(number{})(number{})), scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) @@ -897,7 +900,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); }); }); @@ -932,8 +935,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_ping(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_ping(number{})(number{})), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A 
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) @@ -986,8 +991,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_pong(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_pong(number{})(number{})), scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) @@ -1029,8 +1036,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_ping(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_ping(number{})(number{})), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 969cddf3e7..4d76ab7da2 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -255,9 +255,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; + using TileShape = typename Problem::BlockGemmShape; + using BDataType = remove_cvref_t; + constexpr index_t BPack = numeric_traits::PackedSize; static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16"); @@ -282,21 +284,56 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tile_distribution_encoding< // sequence, 
tuple, // 4 2 - sequence>, // 1 64 32 + sequence>, // 1 64 32 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2>, sequence<2>>, tile_distribution_encoding< // sequence, - tuple, // 4 2 - sequence>, // 2 1 64 16 + tuple, // 4 2 + sequence>, // 2 1 64 16 tuple, sequence<2>>, tuple, sequence<2>>, sequence<2, 2>, sequence<0, 3>>>{}); } + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp) + { + + using BDataType = remove_cvref_t; + constexpr auto BPackedSize = numeric_traits::PackedSize; + constexpr auto kKPerBlock = Problem::BlockGemmShape::kK; + constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1); + constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp; + constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp; + + static_assert(std::decay_t::get_num_of_dimension() == 2); + auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); + const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); + constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile; + auto&& byte_tensor_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple( + flat_n, flat_k / flat_k_per_block, number{})), + make_tuple(make_pass_through_transform(flat_n), + make_merge_transform_v3_division_mod(make_tuple( + flat_k / flat_k_per_block, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); + auto&& byte_tensor_view = + make_tensor_view(byte_ptr, byte_tensor_desc); + auto&& origin_tmp = window_tmp.get_window_origin(); + return make_tile_window( + byte_tensor_view, + make_tuple(number{}, number{}), + {origin_tmp[0], origin_tmp[1] / BPackedSize}, + MakeMX_BFlatBytesDramTileDistribution()); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() { From 
ca6143f0b2237a1af80ef5550f1b774fd463676d Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Mon, 8 Dec 2025 20:44:17 +0500 Subject: [PATCH 23/65] Add a workaround for a compiler issue for bwd on gfx90a and ROCm 7.1.1 (#3369) Sometimes there are not enough wait-states between v_mfma_f32... and v_accvgpr_read_b32 instructions if they are separated by s_cbranch. The workaround is to read accvgprs to vgpr before branching. --- .../block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 854e45c432..7cc424597a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -552,6 +552,15 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR }); }); } +#if defined(__gfx9__) + else + { + // Workaround for a compiler issue: sometimes there are not enough wait-states + // between v_mfma_f32... and v_accvgpr_read_b32 instructions if they are separated + // by s_cbranch. 
+ tile_elementwise_inout([](auto& x) { asm("; force move to %0" : "+v"(x)); }, s_acc); + } +#endif { bool need_perpixel_check = mask.IsEdgeTile( From fe07b5a1bff597df8a81fb227aee0ac95e06b197 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Mon, 8 Dec 2025 21:19:22 +0100 Subject: [PATCH 24/65] [CK Tile] Grouped GEMM aquant mode and non-persistent kernel (#3337) * wip: add aquant to grouped gemm quant example * fix: properly handle hot loop count in aquant pipeline * fix: add separate GemmConfig structs for AQuant, automatically select the correct one * feat: finish support for a non-persistent kernel invocation for grouped gemm quant, and add support code to example * refactor: cleaned up grouped gemm quant example a bit by reusing pipeline selection logic * chore: add warp gemm dispatchers for a couple of TransposeC K=32 variants * feat: add quant grouped gemm tests cases for aquant (regular and transpose C) and non-persistent kernel * fix: update base pipeline classes according to changes in develop branch * Revert "chore: add warp gemm dispatchers for a couple of TransposeC K=32 variants" This reverts commit b3fd4d326d9ccb13e6902bd470bbe76fb323ba54. 
* feat: remove aquant config from grouped gemm quant example, update to add persistency as runtime parameter * chore: removed work-around for aquant bug that has been fixed * chore: fix typo in command-line parameters * fix: correct K warp tile size for gfx950 * chore: incorrect warp tile configuration on gfx942 --- .../17_grouped_gemm/quant_grouped_gemm.cpp | 225 ++++++++++-- .../17_grouped_gemm/quant_grouped_gemm.hpp | 75 +++- .../quant_run_grouped_gemm_example.inc | 233 ++++++++---- .../kernel/grouped_gemm_quant_kernel.hpp | 109 +++++- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 49 ++- .../ck_tile/grouped_gemm_quant/CMakeLists.txt | 3 + .../test_grouped_gemm_quant.cpp | 51 +-- .../test_grouped_gemm_quant_aquant.cpp | 38 ++ .../test_grouped_gemm_quant_bquant.cpp | 11 +- .../test_grouped_gemm_quant_rowcol.cpp | 13 +- .../test_grouped_gemm_quant_tensor.cpp | 13 +- .../test_grouped_gemm_util_quant.hpp | 334 +++++++++++++++--- 12 files changed, 948 insertions(+), 206 deletions(-) create mode 100644 test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index d8b905fe3d..d3b75ac72f 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -9,14 +9,190 @@ #include #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm_quant.hpp" #include "ck_tile/host.hpp" #include "quant_grouped_gemm.hpp" +template +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: 
+ sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + 
if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + template ; // Persistence + GemmConfig::Persistent>; float ave_time{0}; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; - using QuantGemmProblem = typename std::conditional< - QuantMode == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; - using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; using GemmEpilogue = 
ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(argc, argv); + int result1 = run_grouped_gemm_example(argc, argv); return result1; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index ede683abe6..0317685770 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -64,6 +64,7 @@ struct GemmTypeConfig using CDataType = ck_tile::half_t; }; +template struct GemmConfigBase { static constexpr bool kPadM = false; @@ -83,10 +84,11 @@ struct GemmConfigBase static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = false; static constexpr bool PreshuffleB = false; + static constexpr bool Persistent = Persistent_; }; -template -struct GemmConfigComputeV3_2 : public GemmConfigBase +template +struct GemmConfigComputeV3_2 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -101,8 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; -template -struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase +template +struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -121,6 +123,66 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr bool DoubleSmemBuffer = true; }; +template +struct GemmQuantConfig; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = 
ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; + + template + using GemmPipeline = std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>; + + template + using BaseGemmPipeline = + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; +}; + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; auto create_args(int argc, char* argv[]) @@ -148,8 +210,9 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") .insert("kbatch", "1", "kbatch for SplitK") - .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol") - .insert("init", "0", "0. Random, 2. One(s) (Constant)"); + .insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol") + .insert("init", "0", "0. Random, 2. One(s) (Constant)") + .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 
1: persistent."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 37fab44f77..37832b54ba 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -57,56 +57,83 @@ float invoke_gemm(int n_warmup, float ave_time = 0; - // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have - // the gemm problems known on the host. Instead, we can just pass the pointer - // to the kernel and let the workgroups figure out which tiles to work on. - // This is useful when the gemm problems are generated dynamically. - // In this example however, we generate the `kargs` using the known gemm_descs, - // and copy the gemm descriptions to the device memory. - // The contents of the memory pointed to by `kargs_ptr` pointer could be - // written by e.g. another kernel from earlier stage. - std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - assert(args[0].k_batch == 1); - for(const auto& arg : args) + if constexpr(!GemmConfig::Persistent) { - kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.aq_ptr, - arg.bq_ptr, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.QK_A, - arg.QK_B, - arg.stride_A, - arg.stride_B, - arg.stride_E, - arg.stride_AQ, - arg.stride_BQ, - arg.k_batch}); + ave_time = + grouped_gemm(args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. 
+ // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. + std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if(args[0].k_batch != 1) + { + throw std::runtime_error("Split-K not supported yet for persistent kernel"); + } + + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); } - const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), - hipMemcpyHostToDevice, - stream.stream_id_)); - ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; @@ -259,13 +286,24 @@ int run_grouped_gemm_example_with_layouts(int argc, AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + 
throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for AQuantGrouped mode"); + } + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { AQK = 0; // No A quantization BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize if(K % QuantGroupSize::kK != 0) { - throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } @@ -284,6 +322,12 @@ int run_grouped_gemm_example_with_layouts(int argc, stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -311,10 +355,17 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout)))); } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout)))); + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout)))); + ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout)))); bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); } @@ -444,7 +495,7 @@ int 
run_grouped_gemm_example_with_layouts(int argc, bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -477,7 +539,7 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -494,6 +556,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { + return run_grouped_gemm_example_with_layouts typename GemmConfig> +template +int run_gemm_example_persistency( + std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[]) +{ + if(persistent) + { + using GemmConfig = GemmQuantConfig::template GemmConfig; + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else + { + using GemmConfig = GemmQuantConfig::template GemmConfig; + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } +} + int run_grouped_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -524,29 +604,29 @@ int run_grouped_gemm_example(int argc, char* argv[]) const std::string b_layout = arg_parser.get_str("b_layout"); const std::string data_type = arg_parser.get_str("prec"); std::string quant_mode = arg_parser.get_str("quant_mode"); + bool persistent = arg_parser.get_bool("persistent"); if(data_type == "fp8") { if(quant_mode == "tensor") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::TensorQuant>( - a_layout, b_layout, 
argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "rowcol") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::RowColQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "bquant") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else { @@ -557,24 +637,23 @@ int run_grouped_gemm_example(int argc, char* argv[]) { if(quant_mode == "tensor") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::TensorQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "rowcol") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::RowColQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "bquant") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else { diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index caa6aad363..726f678d37 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ 
b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -163,7 +163,6 @@ struct QuantGroupedGemmKernel static constexpr index_t kBlockSize = GemmPipeline::BlockSize; static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; - static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true"); [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -262,10 +261,9 @@ struct QuantGroupedGemmKernel auto karg = QuantGroupedGemmKernelArgs{type_convert(gemm_descs[i].a_ptr), type_convert(gemm_descs[i].b_ptr), - type_convert(gemm_descs[i].e_ptr), type_convert(gemm_descs[i].aq_ptr), type_convert(gemm_descs[i].bq_ptr), - gemm_descs[i].k_batch, + type_convert(gemm_descs[i].e_ptr), M, N, K, @@ -275,7 +273,8 @@ struct QuantGroupedGemmKernel stride_b, stride_e, gemm_descs[i].stride_AQ, - gemm_descs[i].stride_BQ}; + gemm_descs[i].stride_BQ, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -342,16 +341,32 @@ struct QuantGroupedGemmKernel else { - RunGemmWithPipelineSelection(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); + if constexpr(UsePersistentKernel) + { + RunGemmWithPipelineSelection(a_ptr, + b_ptr, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else // Non-persistent kernel + { + Base::RunGemm({a_ptr}, + {b_ptr}, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } } } @@ -451,7 +466,24 @@ struct QuantGroupedGemmKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - if constexpr(kQuantType == QuantType::BQuantGrouped) + if constexpr(kQuantType == QuantType::AQuantGrouped) + { + const auto& aq_block_window = gemm_tile_windows.at(Base::I1); + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template 
operator()(a_block_window, + b_block_window, + aq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + auto& c_block_window = gemm_tile_windows.at(Base::I4); + + // Run Epilogue Pipeline + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) { const auto& bq_block_window = gemm_tile_windows.at(Base::I3); // Run GEMM pipeline @@ -496,6 +528,53 @@ struct QuantGroupedGemmKernel } } + CK_TILE_DEVICE index_t FindGroupId(const QuantGemmTransKernelArg* gemm_desc_ptr, + index_t block_id, + index_t group_count) const + { + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) >> 1); + + while((!(block_id >= gemm_desc_ptr[group_id].block_start && + block_id < gemm_desc_ptr[group_id].block_end)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) >> 1); + } + + return group_id; + } + + // For non-persistent kernels + template > + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + index_t group_count) const + { + const index_t block_id = ck_tile::get_block_1d_id(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); + const auto& kargs = gemm_desc_ptr[group_id]; + + const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, + kargs.group_karg.M, + kargs.group_karg.N, + (block_id - kargs.block_start) % grid_size_2d); + Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); + } + // For persistent kernels template , diff --git 
a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 30b9d70eb8..e7bd4a2626 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -319,6 +319,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem, + index_t m = 0) const + { + const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) { + constexpr bool hot_loop = has_hot_loop_.value; + constexpr auto tail_num = tail_number_.value; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + m, // dummy value, won't be used + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } }; } // namespace ck_tile diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 2bd2571993..7a7ae77730 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -14,6 +14,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE 
${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp index 551989421f..6a1a28884a 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp @@ -18,32 +18,41 @@ using True = ck_tile::bool_constant; using False = ck_tile::bool_constant; using RowColQuant = std::integral_constant; using TensorQuant = std::integral_constant; +using AQuant = std::integral_constant; using BQuant = std::integral_constant; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - 
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, 
TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp new file mode 100644 index 0000000000..8dcd6d017d --- /dev/null +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_util_quant.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using AQuant = std::integral_constant; + +// clang-format off +using KernelTypes_AQuant = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_AQuant, KernelTypes_AQuant); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_AQuant +#include "test_grouped_gemm_quant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp index 4f44acf4c4..6c0ad545b7 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp @@ -20,9 +20,14 @@ using BQuant = std::integral_constant, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, 
BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, False, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp index 48720aeebf..cc1b32fb20 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp @@ -20,11 +20,14 @@ using RowColQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False> >; // clang-format on 
diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp index f59fa29ec2..e446f7b168 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp @@ -20,11 +20,14 @@ using TensorQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index 68b6735655..9941066c3e 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -3,6 +3,7 @@ #pragma once #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" @@ -32,24 +33,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using AQLayout = Row; using BQLayout = Col; - static constexpr 
bool Persistent = true; static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value; - - template - static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() - { -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif - } + static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value; + static constexpr bool TransposeC = std::tuple_element_t<12, Tuple>::value; struct GroupedGemKernelParam_Mfma { @@ -66,11 +52,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test static const ck_tile::index_t N_Warp = 2; static const ck_tile::index_t K_Warp = 1; - static const ck_tile::index_t M_Warp_Tile = 32; - static const ck_tile::index_t N_Warp_Tile = 32; - static const ck_tile::index_t K_Warp_Tile = - TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile(); + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = 32; }; struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma @@ -90,16 +74,201 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); } + template + float invoke_grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + constexpr bool DoubleSmemBuffer = + PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B + + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + + using QuantGroupSize = ck_tile::QuantGroupShape>; + + using GemmShape 
= + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrCompV3, + std::conditional_t< + PreshuffleB == true, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, + ck_tile::BaseGemmPipelineAgBgCrCompV3>>, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile; + const ck_tile::index_t K_split = + (gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + 
ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GroupedGemKernelParam::M_Warp, + GroupedGemKernelParam::N_Warp, + GroupedGemKernelParam::M_Warp_Tile, + GroupedGemKernelParam::N_Warp_Tile, + GroupedGemKernelParam::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } + template void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr) { - constexpr bool TransposeC = false; constexpr bool DoubleSmemBuffer = PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B - constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -131,40 +300,53 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test BQLayout, TransposeC, DoubleSmemBuffer, - true>; + 
Persistent>; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; // We create the GEMM pipeline without specifying hotloop or tailnumber. // These are automatically run inside the kernel based on the given input data. - using QuantGemmProblem = typename std::conditional< - QuantType == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; using GemmPipeline = std::conditional_t< - QuantType == ck_tile::QuantType::RowColQuant || - QuantType == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -292,13 +474,24 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for AQuantGrouped mode"); + } + } 
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { - AQK = 0; // No A quantization - BQK = K / 128; // Group quantization: BQK = K / GroupSize - if(K % 128 != 0) + AQK = 0; // No A quantization + BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + if(K % QuantGroupSize::kK != 0) { - throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } @@ -317,6 +510,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout())); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -348,11 +547,20 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::HostTensor(ck_tile::host_tensor_descriptor( 1, 1, stride_BQs[i], is_row_major(BQLayout())))); } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + 0, 0, stride_BQs[i], is_row_major(BQLayout())))); + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back( ck_tile::HostTensor(ck_tile::host_tensor_descriptor( - 0, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + 0, 0, stride_AQs[i], is_row_major(AQLayout{})))); bq_tensors.push_back( ck_tile::HostTensor(ck_tile::host_tensor_descriptor( BQK, N, stride_BQs[i], is_row_major(BQLayout())))); @@ -429,11 +637,12 @@ class TestCkTileGroupedGemmQuant : public 
::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if constexpr(Persistent) { // Generate kernel arguments std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); assert(gemm_descs[0].k_batch == 1); for(const auto& arg : gemm_descs) { @@ -471,7 +680,14 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test } else { - GTEST_FAIL() << "Non-persistent kernel not implemented yet"; + const auto stream = ck_tile::stream_config{nullptr, false, 1}; +#if CK_TILE_USE_WMMA + invoke_grouped_gemm( + gemm_descs, stream, kargs_ptr); +#else + invoke_grouped_gemm( + gemm_descs, stream, kargs_ptr); +#endif } // Copy results back to host for validation @@ -512,7 +728,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -550,5 +777,8 @@ using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant; template using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant; +template +using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant; + template using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant; From c363a98d4154c647c1a2d5331ad0d76879b84dfa Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 8 Dec 2025 21:05:56 +0000 Subject: [PATCH 25/65] [CK_TILE] Support more layouts for BQuant GEMM (#3349) * WIP: preparing to add transpose bq support * WIP: handle both row/col layout for BQ windows/tile dstr * Fix build * WIP: adding some test, debugging 
numerical errors * Fix all but pkint4 tests * Remove test_gemm_quant_typed.cpp again * update disabled tests * add conversion from pkint4 for b matrix * fix formatting * fix formatting * Fix tr_load and use override b datatype for clarity * fix formatting * make bquant preshuffle tests bqlayout column-major --- .../block_universal_gemm_as_bs_bquant_cr.hpp | 32 +++- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 66 ++++++-- .../gemm_bquant_pipeline_ag_bg_cr_base.hpp | 14 +- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 22 ++- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 85 +++++++--- .../pipeline/gemm_group_quant_utils.hpp | 151 ++++++++++++------ .../gemm_block_scale/test_gemm_quant_base.hpp | 4 +- .../test_gemm_quant_bquant.cpp | 76 +++++---- .../test_gemm_quant_bquant_preshuffle.cpp | 90 +++++------ .../test_gemm_quant_fixtures.hpp | 8 +- 10 files changed, 359 insertions(+), 189 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index d97145cbc3..628e9194ae 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using BQDataType = remove_cvref_t; + using BQLayout = remove_cvref_t; using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -154,6 +155,10 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; + using Base = BlockGemmBQuantBase; using WarpGemm = remove_cvref_t; @@ -271,12 +276,20 @@ 
struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { - load_int4_tile(a_warp_tile_, a_block_window); - load_int4_tile(b_warp_tile_, b_block_window); + load_int4_tile( + a_warp_tile_, a_block_window); + // If B datatype were pkint4 it would be converted prior to storing in LDS + load_int4_tile( + b_warp_tile_, b_block_window); } // C += A * B @@ -397,11 +410,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase MakeCBlockTile(); } - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index dd85705cf2..203b79aec6 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -426,7 +426,6 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::BQuantGrouped) { - static_assert(std::is_same_v); if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) @@ -781,7 +780,9 @@ struct QuantGemmKernel { if constexpr(PreshuffleQuant) { - static_assert(std::is_same_v); + static_assert(std::is_same_v, + "PreshuffleQuant with BQuantGrouped currently only supports " + "ColumnMajor BQ layout"); return MakePreshuffledQuantTensorView< 
GemmPipeline::KPerBlockBQ, @@ -791,14 +792,35 @@ struct QuantGemmKernel } else { - static_assert(std::is_same_v); using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), - make_tuple(kargs.stride_BQ, 1), - number{}, - number<1>{}); + + if constexpr(std::is_same_v) + { + // For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN] + // Dimensions: [K/QuantGroupK, N/QuantGroupN] + // Strides: [N/QuantGroupN, 1] + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), + integer_divide_ceil(kargs.N, QuantGroupSize::kN)), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), + number{}, + number<1>{}); + } + else + { + static_assert(std::is_same_v); + // For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK] + // Dimensions: [N/QuantGroupN, K/QuantGroupK] + // Strides: [K/QuantGroupK, 1] + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), + integer_divide_ceil(kargs.K, QuantGroupSize::kK)), + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), + number{}, + number<1>{}); + } } } else @@ -1023,10 +1045,10 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::BQuantGrouped) { + using QuantGroupSize = remove_cvref_t; if constexpr(PreshuffleQuant) { static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; @@ -1042,13 +1064,23 @@ struct QuantGemmKernel } else { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); + if 
constexpr(std::is_same_v) + { + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {0, i_n / QuantGroupSize::kN}); + } + else + { + static_assert(std::is_same_v); + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + } } } else diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp index 4cd343e640..c570d4a131 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp @@ -42,14 +42,18 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase); - - using YPerTile = number; - using XPerTile = number; + using YPerTile = + std::conditional_t, + number, + number>; + using XPerTile = + std::conditional_t, + number, + number>; auto bq_copy_dram_window = make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(YPerTile(), XPerTile()), + make_tuple(YPerTile{}, XPerTile{}), bq_dram_block_window_tmp.get_window_origin(), Policy::template MakeBQDramTileDistribution()); return bq_copy_dram_window; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 870326cb9d..154d068f0a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -25,8 +25,16 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; - static_assert(std::is_same_v); - return GetABQGlobalVectorLoadSize(); + // Support both RowMajor and ColumnMajor 
layouts for BQ + if constexpr(std::is_same_v) + { + return GetABQGlobalVectorLoadSize(); + } + else + { + static_assert(std::is_same_v); + return GetABQGlobalVectorLoadSize(); + } } template @@ -52,7 +60,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC WarpTile::at(I2), Problem::TransposeC>; - static_assert(std::is_same_v); if constexpr(PreshuffleQuant) { using TileEncodingPattern = tile_distribution_encoding_pattern_bq< @@ -62,18 +69,21 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC NPerBlock / WarpGemm::kN, ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()), VecLoadSize, + BQLayout, PreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); } else { + // KPerTile and NPerTile are LOGICAL dimensions (K quant groups and N quant groups) using TileEncodingPattern = tile_distribution_encoding_pattern_bq; + KPerBlockBQ, // Logical K dimension + NPerBlockBQ, // Logical N dimension + Problem::QuantGroupSize::kN, + BQLayout>; return TileEncodingPattern::make_2d_static_tile_distribution(); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 4883a30f57..2c191cc2b4 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -33,6 +33,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using QuantGroupSize = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; using I1 = number<1>; @@ -83,6 +87,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; + static constexpr 
auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -125,7 +132,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile, + const BDramWindow& b_dram_window) + { + using DestDataType = typename BBlockTile_::DataType; + using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(b_block_tile, b_dram_window); + } + template ; - constexpr bool is_bq_col_major = - std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + constexpr bool is_bq_row_major = + std::is_same_v; static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -212,12 +227,22 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -237,7 +262,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using BQBlockTile = decltype(make_static_distributed_tensor(BQBlockTileDistr{})); @@ -258,18 +283,20 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}), 0) - : is_bq_col_major ? make_array(0, KPerBlockBQ) - : make_array(KPerBlockBQ, 0); + : is_bq_row_major ? 
make_array(KPerBlockBQ, 0) + : make_array(0, KPerBlockBQ); // DRAM prefetch (global read 0) Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // B tile gets converted to A datatype during loading + LoadAndConvertBTile(b_block_tile, b_copy_dram_window); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); Base::GlobalPrefetch( bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -281,9 +308,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // B datatype is converted to A datatype during loading + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -294,11 +322,13 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledARegTileDistribution()); @@ -322,9 +352,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: BDataType PkInt4 gets converted during loading earlier + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -335,7 +366,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: BDataType gets converted during loading from PkInt4 + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); 
transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -393,7 +427,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { @@ -210,36 +211,41 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) /// /// This function determines the optimal thread distribution pattern for loading and applying - /// quantization scales to the B matrix based on the quantization group size (XPerQ) relative + /// quantization scales to the B matrix based on the quantization group size (NPerQ) relative /// to warp dimensions. /// /// Three distinct distribution patterns are handled: /// - /// 1. Fine-grained quantization (XPerQ < WarpGemm::kN): + /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): /// - Multiple quantization groups exist within a single warp's N-dimension - /// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp) - /// - Distribution includes explicit replication factor (XR = XPerQ) for scale broadcast - /// - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp + /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) + /// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast + /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp /// - /// 2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps): + /// 2. 
Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): /// - Each warp handles exactly one quantization scale - /// - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN - /// - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 + /// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN + /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 /// - /// 3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps): + /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): /// - Quantization group spans multiple warps /// - All warps share the same scale value - /// - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale + /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale /// /// @return A static tile distribution encoding for the BQ scale tensor CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { + // Preshuffle only supported for ColumnMajor currently + static_assert(!(PreshuffleQuant && std::is_same_v), + "PreshuffleQuant only supported for ColumnMajor BQLayout"); + if constexpr(PreshuffleQuant) { + // ColumnMajor only for preshuffle constexpr index_t X1 = warp_size; - constexpr index_t X0 = XPerTile / warp_size; + constexpr index_t X0 = NPerTile / warp_size; constexpr index_t Y1 = NWarps; - constexpr index_t Y0 = YPerTile / Y1; + constexpr index_t Y0 = KPerTile / Y1; return make_static_tile_distribution( tile_distribution_encoding, @@ -251,52 +257,97 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding } else { - if constexpr(YPerQ < WarpGemm::kN) + if constexpr(NPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp - constexpr index_t X = XPerTile; // Full X dimension of tile - constexpr index_t XR = 1; // No Y replication needed - constexpr index_t Y0 = NIterPerWarp; // 
Iterations per warp in N-dim - constexpr index_t Y1 = NWarps; // Number of warps in N-dim - constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp - constexpr index_t YR = YPerQ; // Elements per quantization group + // N dimension needs to be partitioned the same way regardless of layout + constexpr index_t NR = 1; // No N replication needed + constexpr index_t N0 = NIterPerWarp; // Iterations per warp in N-dim + constexpr index_t N1 = NWarps; // Number of warps in N-dim + constexpr index_t N2 = WarpGemm::kN / NPerQ; // Number of scales per warp - static_assert(Y0 * Y1 * Y2 == YPerTile, - "Y0, Y1, Y2 must cover the blocktile along Y."); + static_assert(N0 * N1 * N2 == NPerTile, + "N0, N1, N2 must cover the blocktile along N dimension."); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0, 1, 0>>, - tuple, sequence<1, 2, 2>>, - sequence<1, 2>, - sequence<0, 0>>{}); + if constexpr(std::is_same_v) + { + // ColumnMajor: [(N0, N1, N2), K] - N on Y-axis, partition Y + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1, 0>>, + tuple, sequence<1, 2, 2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, (N0, N1, N2)] - N on X-axis, partition X + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0>>, + tuple, sequence<1, 2, 2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } - else if constexpr(YPerQ <= WarpGemm::kN * NWarps) + else if constexpr(NPerQ <= WarpGemm::kN * NWarps) { // Case 2: Medium-grained - one quantization scale per warp - constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor - constexpr auto Y1 = NWarps / YR; // Warps per unique scale - constexpr auto Y0 = YPerTile / Y1; // Iterations to cover X dimension - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, - 
tuple, sequence<2>>, - sequence<1, 2>, - sequence<0, 0>>{}); + constexpr auto NR = NPerQ / WarpGemm::kN; // Scale replication factor + constexpr auto N1 = NWarps / NR; // Warps per unique scale + constexpr auto N0 = NPerTile / N1; // Iterations to cover N dimension + + if constexpr(std::is_same_v) + { + // ColumnMajor: [(N0, N1), K] - N on Y-axis + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, (N0, N1)] - N on X-axis + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } - else // XPerQ > WarpGemm::kN * NWarps + else // NPerQ > WarpGemm::kN * NWarps { // Case 3: Coarse-grained - quantization group spans all warps // All warps in N-dimension share the same quantization scale - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, - tuple, sequence<2>>, - sequence<2, 1>, - sequence<0, 0>>{}); + if constexpr(std::is_same_v) + { + // ColumnMajor: [N, K] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, N] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } } } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 38bd59b882..39a7c66f38 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -86,8 +86,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test using TilePartitioner = 
ck_tile::GemmTile1DPartitioner; - // BQLayout is always ColumnMajor for BQuant - using BQLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + // Re-use the AQLayout for BQLayout + using BQLayout = AQLayout; using CodegenGemmTraits = ck_tile::TileGemmQuantTraits>; using GroupSize2D128N = ck_tile::QuantGroupShape>; // Type combinations for BQuant tests (without PreshuffleB) -// Tuple format: // clang-format off using BQuantTypes = ::testing::Types< - // 1d cases with grouping only on k axis (AQLayout is always RowMajor for BQuant) - std::tuple, - std::tuple, - std::tuple, - std::tuple, + // 1d cases with grouping only on k axis + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // some cases with transpose layouts + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple, + + // pkint4 + transpose cases + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, 
PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp index 6cde4bded5..3a62fc091a 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -26,60 +26,60 @@ using GroupSize2D32N = ck_tile::QuantGroupShape>; using GroupSize2D64N = ck_tile::QuantGroupShape>; // Type combinations for BQuant tests with PreshuffleB -// Tuple format: // clang-format off using BPreshuffleBQuantTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, // //2d cases with preshuffle B - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on 
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 7b16529aa8..bf9c7a138d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -389,6 +389,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBaseis_row_major(BQLayout{}) ? BQN : BQK; // Generate test data ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); ck_tile::HostTensor b_k_n( ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); - // BQ is always ColumnMajor ck_tile::HostTensor bq_bqk_bqn( - ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, ck_tile::bool_constant{})); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{}))); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); From c1c2e41a0387e8e76970ad86959e28963f569d54 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Tue, 9 Dec 2025 11:02:33 +0800 Subject: [PATCH 26/65] [CK_TILE] Generate random tensor values with multiple threads (#3324) --- example/ck_tile/15_fused_moe/main.cpp | 33 ++-- .../18_flatmm/mxgemm/run_mx_flatmm.inc | 16 +- include/ck_tile/host/fill.hpp | 96 ++++++----- include/ck_tile/host/joinable_thread.hpp | 49 ++++++ test/ck_tile/utility/CMakeLists.txt | 2 + test/ck_tile/utility/test_fill.cpp | 156 ++++++++++++++++++ 6 files changed, 286 insertions(+), 66 deletions(-) create mode 100644 test/ck_tile/utility/test_fill.cpp diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index ac174379df..efb83efbd2 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -284,26 +284,25 @@ bool run(const ck_tile::ArgParser& arg_parser) } else if(init == 1) { - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(a_host); - 
ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(g_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(d_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sa_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sg_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sd_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sy_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}( - topk_weight_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(g_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(d_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sa_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sg_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sd_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sy_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(topk_weight_host); } else if(init == 2) { - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(a_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(g_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(d_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sa_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sg_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sd_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sy_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(topk_weight_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(a_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(g_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(d_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sa_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sg_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sd_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sy_host); + ck_tile::FillNormalDistribution{0.f, 
1.f, seed}(topk_weight_host); } // permute weight diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index 44fd12e2d9..cc2c041ed6 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -71,17 +71,17 @@ int run_mx_flatmm_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{0.0f, 1.0f}(a_host); - ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); - ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_a); - ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_b); + ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); + ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); } else if(init_method == 1) { - ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); - ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); - ck_tile::FillUniformDistribution{1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution{1.f, 1.f}(scale_b); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); } else { diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 12f43ebc5e..4bbf8cbf3f 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -33,59 +33,73 @@ namespace ck_tile { * @example * * // Direct usage without creating a separate variable: - * ck_tile::FillUniformDistribution{-1.f, 1.f}(a_host_tensor); + * ck_tile::FillUniformDistribution<>{-1.f, 1.f}(a_host_tensor); */ -template +template struct FillUniformDistribution { float a_{-5.f}; float b_{5.f}; std::optional seed_{11939}; - // ATTENTION: Whether to use multi-threading (note: not guaranteed to be perfectly distributed 
- // across threads). - bool threaded = false; template void operator()(ForwardIter first, ForwardIter last) const { - if(threaded) - { - uint32_t num_thread = std::thread::hardware_concurrency(); - auto total = static_cast(std::distance(first, last)); - auto work_per_thread = static_cast((total + num_thread - 1) / num_thread); + if(first == last) + return; + using T_iter = std::decay_t; + static_assert(std::is_same_v || std::is_void_v, + "Iterator value type must match template type T"); + constexpr auto PackedSize = numeric_traits::PackedSize; + const auto total = static_cast(std::distance(first, last)); + const auto total_bytes = total * sizeof(T_iter); - std::vector threads(num_thread); - for(std::size_t it = 0; it < num_thread; ++it) - { - std::size_t iw_begin = it * work_per_thread; - std::size_t iw_end = std::min((it + 1) * work_per_thread, total); - auto thread_f = [this, total, iw_begin, iw_end, &first] { - if(iw_begin > total || iw_end > total) - return; - // need to make each thread unique, add an offset to current seed - std::mt19937 gen(seed_.has_value() ? 
(*seed_ + iw_begin) - : std::random_device{}()); - std::uniform_real_distribution dis(a_, b_); - std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { - if constexpr(numeric_traits::PackedSize == 2) - return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); - else - return ck_tile::type_convert(dis(gen)); - }); - }; - threads[it] = joinable_thread(thread_f); - } - } - else + // max 80 threads; at least 2MB per thread + const size_t available_cpu_cores = get_available_cpu_cores(); + const size_t num_thread = + min(80UL, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL)); + constexpr size_t BLOCK_BYTES = 64; + constexpr size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T_iter); + const size_t num_blocks = integer_divide_ceil(total_bytes, BLOCK_BYTES); + const size_t blocks_per_thread = integer_divide_ceil(num_blocks, num_thread); + + // use minstd_rand for better performance on discard() + std::minstd_rand gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::uniform_real_distribution dis(a_, b_); + + std::vector threads; + threads.reserve(num_thread - 1); // last job run in the main thread + for(int it = num_thread - 1; it >= 0; --it) { - std::mt19937 gen(seed_.has_value() ? 
*seed_ : std::random_device{}()); - std::uniform_real_distribution dis(a_, b_); - std::generate(first, last, [&dis, &gen]() { - if constexpr(numeric_traits::PackedSize == 2) - return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); - else - return ck_tile::type_convert(dis(gen)); - }); + const size_t ib_begin = it * blocks_per_thread; + const size_t ib_end = min(ib_begin + blocks_per_thread, num_blocks); + + auto job = [=]() { + auto g_ = gen; // copy + auto d_ = dis; // copy + g_.discard(ib_begin * BLOCK_SIZE * PackedSize); + auto t_fn = [&]() { + if constexpr(PackedSize == 2) + return type_convert(fp32x2_t{d_(g_), d_(g_)}); + else + return type_convert(d_(g_)); + }; + + size_t ib = ib_begin; + for(; ib < ib_end - 1; ++ib) // full blocks + static_for<0, BLOCK_SIZE, 1>{}([&](auto iw_) { + constexpr size_t iw = iw_.value; + *(first + ib * BLOCK_SIZE + iw) = t_fn(); + }); + for(size_t iw = 0; iw < BLOCK_SIZE; ++iw) // last block + if(ib * BLOCK_SIZE + iw < total) + *(first + ib * BLOCK_SIZE + iw) = t_fn(); + }; + + if(it > 0) + threads.emplace_back(std::move(job)); + else + job(); // last job run in the main thread } } diff --git a/include/ck_tile/host/joinable_thread.hpp b/include/ck_tile/host/joinable_thread.hpp index bf84858ee2..b2e1fc4dac 100644 --- a/include/ck_tile/host/joinable_thread.hpp +++ b/include/ck_tile/host/joinable_thread.hpp @@ -3,6 +3,9 @@ #pragma once +#ifdef __linux__ +#include +#endif #include #include @@ -24,4 +27,50 @@ struct joinable_thread : std::thread this->join(); } }; + +inline unsigned int get_available_cpu_cores() +{ +#if defined(__linux__) + cpu_set_t cpu_set; + if(sched_getaffinity(0, sizeof(cpu_set_t), &cpu_set) == 0) + { + unsigned int cpu_count = CPU_COUNT(&cpu_set); + if(cpu_count > 0) + return cpu_count; + } +#endif + // Fallback if sched_getaffinity unavailable or fails + return std::thread::hardware_concurrency(); +} + +class cpu_core_guard +{ +#if defined(__linux__) + cpu_set_t original_cpu_set_; + + public: + 
cpu_core_guard(unsigned int num_cores) : original_cpu_set_() + { + // save original cpu set + sched_getaffinity(0, sizeof(cpu_set_t), &original_cpu_set_); + + // set new cpu set + cpu_set_t new_cpu_set; + CPU_ZERO(&new_cpu_set); + for(unsigned int i = 0; i < num_cores; ++i) + { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + CPU_SET(i, &new_cpu_set); // NOLINT(old-style-cast) +#pragma clang diagnostic pop + } + sched_setaffinity(0, sizeof(cpu_set_t), &new_cpu_set); + } + ~cpu_core_guard() + { + // restore original cpu set + sched_setaffinity(0, sizeof(cpu_set_t), &original_cpu_set_); + } +#endif +}; } // namespace ck_tile diff --git a/test/ck_tile/utility/CMakeLists.txt b/test/ck_tile/utility/CMakeLists.txt index aa15293411..01ed83841b 100644 --- a/test/ck_tile/utility/CMakeLists.txt +++ b/test/ck_tile/utility/CMakeLists.txt @@ -3,5 +3,7 @@ message("-- Adding: test/ck_tile/utility/") +add_gtest_executable(test_fill test_fill.cpp) + # Add print tests add_subdirectory(print) diff --git a/test/ck_tile/utility/test_fill.cpp b/test/ck_tile/utility/test_fill.cpp new file mode 100644 index 0000000000..18f42c4ad0 --- /dev/null +++ b/test/ck_tile/utility/test_fill.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck_tile/host/fill.hpp" +#include "ck_tile/host/joinable_thread.hpp" +#include +#include +#include +#include + +using namespace ck_tile; + +namespace test { + +// Test fixture for FillUniformDistribution tests +template +class FillUniformDistributionTest : public ::testing::Test +{ + public: + static constexpr uint32_t seed = 42; + static constexpr float a = -5.0f; + static constexpr float b = 5.0f; +}; + +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(FillUniformDistributionTest, TestTypes); + +// Test that multiple runs with the same seed produce identical results +TYPED_TEST(FillUniformDistributionTest, ConsistencyWithSameSeed) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + constexpr size_t size = 1024 * 1024 * 1024 / sizeof(T); // 1G + + std::vector vec1(size); + auto start = std::chrono::high_resolution_clock::now(); + FillUniformDistribution{a, b, seed}(vec1.begin(), vec1.end()); + auto end = std::chrono::high_resolution_clock::now(); + double sec = std::chrono::duration(end - start).count(); + std::cout << "Taking " << sec << " sec to fill 1GB of data of type " << typeid(T).name() + << std::endl; + + const auto cpu_cores = max(32U, get_available_cpu_cores()); + for(auto num_threads_diff : {-3, -1}) + { + cpu_core_guard cg(min(max(cpu_cores + num_threads_diff, 1U), get_available_cpu_cores())); + std::vector vec2(size); + FillUniformDistribution{a, b, seed}(vec2.begin(), vec2.end()); + EXPECT_EQ(0, std::memcmp(vec1.data(), vec2.data(), size * sizeof(T))) + << "First and second fill should be identical"; + } +} + +// Test consistency across different data sizes (which affects threading) +TYPED_TEST(FillUniformDistributionTest, ConsistencyAcrossSizes) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + std::vector test_sizes = { + 100, // Small - likely single threaded + 10000, // Medium + 1000000, // Large - will 
use multiple threads + 5000000 // Very large - will use many threads + }; + + for(size_t size : test_sizes) + { + std::vector reference(size); + std::vector test_vec(size); + + FillUniformDistribution{a, b, seed}(reference.begin(), reference.end()); + + // Run multiple times to ensure consistency + for(int run = 0; run < 3; ++run) + { + std::fill(test_vec.begin(), test_vec.end(), T{}); + FillUniformDistribution{a, b, seed}(test_vec.begin(), test_vec.end()); + + EXPECT_EQ(0, std::memcmp(reference.data(), test_vec.data(), size * sizeof(T))) + << "Mismatch for size=" << size << " run=" << run; + } + } +} + +// Test that fills of different sizes with the same seed share a common prefix +TYPED_TEST(FillUniformDistributionTest, CommonPrefix) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + std::vector test_sizes = { + 100, // Small - likely single threaded + 10000, // Medium + 1000000, // Large - will use multiple threads + 5000000 // Very large - will use many threads + }; + + auto longest = std::make_unique>(test_sizes[0]); + FillUniformDistribution{a, b, seed}(longest->begin(), longest->end()); + for(size_t i = 1; i < test_sizes.size(); ++i) + { + auto current = std::make_unique>(test_sizes[i]); + FillUniformDistribution{a, b, seed}(current->begin(), current->end()); + size_t min_size = std::min(longest->size(), current->size()); + EXPECT_EQ(0, std::memcmp(longest->data(), current->data(), min_size * sizeof(T))) + << "Different sizes with same seed should have the same prefix"; + if(current->size() > longest->size()) + { + longest = std::move(current); + } + } +} + +// Test edge cases +TYPED_TEST(FillUniformDistributionTest, EdgeCases) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + // Empty range + std::vector empty_vec; + EXPECT_NO_THROW((FillUniformDistribution{a, b, seed}(empty_vec.begin(), empty_vec.end()))); + + // Single element + std::vector single1(1); + 
std::vector single2(1); + FillUniformDistribution{a, b, seed}(single1.begin(), single1.end()); + FillUniformDistribution{a, b, seed}(single2.begin(), single2.end()); + + EXPECT_EQ(0, std::memcmp(single1.data(), single2.data(), sizeof(T))) + << "Single element should be consistent"; + + // Small sizes that might affect threading decisions + std::vector small_sizes = {2, 3, 7, 15, 16, 17, 31, 32, 33, 63, 64, 65}; + for(size_t size : small_sizes) + { + std::vector vec1(size); + std::vector vec2(size); + FillUniformDistribution{a, b, seed}(vec1.begin(), vec1.end()); + FillUniformDistribution{a, b, seed}(vec2.begin(), vec2.end()); + + EXPECT_EQ(0, std::memcmp(vec1.data(), vec2.data(), size * sizeof(T))) + << "Edge case failed for size=" << size; + } +} +} // namespace test From 6f0966e1e9fca5c513d16a729237d676b583e266 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Tue, 9 Dec 2025 17:54:55 +0800 Subject: [PATCH 27/65] fix a16w4 moe bugs (#3373) * fix valid mask bug * update format --- include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index b3b34a6da0..7104547247 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1259,12 +1259,12 @@ struct MoeFlatmmKernel auto fused_token = kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0] - index_t scatter_token_id = fused_token & token_id_mask; + index_t scatter_token_id = fused_token & token_id_mask; + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); if constexpr(IsInputGemm) scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; - c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); }); From 
50ca3f83ebc08ffe8946c3668fd879e3b2043ef7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Dec 2025 07:10:34 -0800 Subject: [PATCH 28/65] Bump rocm-docs-core[api_reference] from 1.20.1 to 1.31.0 in /docs/sphinx (#3374) Bumps [rocm-docs-core[api_reference]](https://github.com/ROCm/rocm-docs-core) from 1.20.1 to 1.31.0. - [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/v1.31.0/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.20.1...v1.31.0) --- updated-dependencies: - dependency-name: rocm-docs-core[api_reference] dependency-version: 1.31.0 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index beedb4e867..b607daa9ff 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core[api_reference]==1.20.1 +rocm-docs-core[api_reference]==1.31.0 sphinxcontrib-bibtex==2.6.5 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index e8aa02aa01..fce859cf0e 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -237,7 +237,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core[api-reference]==1.20.1 +rocm-docs-core[api-reference]==1.31.0 # via -r requirements.in rpds-py==0.24.0 # via From 7582c9e73fc3e580a2255988310cb25391f80162 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 9 Dec 2025 07:35:32 -0800 Subject: [PATCH 29/65] Upgrade to ROCm7.1.1 as default compiler. 
(#3370) * upgrade to rocm7.1.1 as new default compiler * fix jenkinsfile --- Dockerfile | 6 +++--- Dockerfile.compiler | 2 +- Jenkinsfile | 8 ++++---- python/ck4inductor/__init__.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index 07327442fe..973dcedcb5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=7.0.1 +ARG ROCMVERSION=7.1.1 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" @@ -13,8 +13,8 @@ ENV DEBIAN_FRONTEND=noninteractive RUN set -xe && \ apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl -RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \ - apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \ +RUN wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \ + apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \ apt update && \ apt install python3-setuptools python3-wheel -y && \ apt install rocm-dev -y diff --git a/Dockerfile.compiler b/Dockerfile.compiler index 47bd8294b6..0e2219b7ff 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.1.1" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Jenkinsfile b/Jenkinsfile index 45fd576ab6..b8c570b936 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -288,7 +288,7 @@ def getBaseDockerImageName(){ } else{ def ROCM_numeric = parseVersion("${params.ROCMVERSION}") - if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 1 ){ + if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 2 ){ img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" } else{ @@ -434,7 +434,7 @@ def buildDocker(install_prefix){ } catch(Exception ex){ echo "Unable to locate image: 
${image_name}. Building image now" - retimage = docker.build("${image_name}", dockerArgs + ' .') + retimage = docker.build("${image_name}", dockerArgs) withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } @@ -1121,8 +1121,8 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '7.0.1', - description: 'Specify which ROCM version to use: 7.0.1 (default).') + defaultValue: '7.1.1', + description: 'Specify which ROCM version to use: 7.1.1 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', diff --git a/python/ck4inductor/__init__.py b/python/ck4inductor/__init__.py index 0eee25ecaa..089a2d439b 100644 --- a/python/ck4inductor/__init__.py +++ b/python/ck4inductor/__init__.py @@ -6,7 +6,7 @@ def __version__(): import subprocess # needs to be manually updated - rocm_version = "7.0.1" + rocm_version = "7.1.1" hash_width = 6 try: hash = subprocess.check_output("git rev-parse HEAD", shell=True, text=True)[ From 0d8259affd4f59eb8b1143b658d83d3800270f43 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 9 Dec 2025 10:37:13 -0800 Subject: [PATCH 30/65] temporarily disable daily builds on gfx1010 and gfx908 (#3384) --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index b8c570b936..3f94820095 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1095,7 +1095,7 @@ def run_pytorch_tests(Map conf=[:]){ //launch develop branch daily jobs CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true;FORCE_CI=true 0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=false;BUILD_GFX908=false;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true From 934ba1208ab7cfc82c20f73b14994b64c3843d2d Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:39:08 -0800 Subject: [PATCH 31/65] use hipTensor from monorepo for daily builds (#3386) --- Jenkinsfile | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 3f94820095..5f03310cab 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -834,12 +834,14 @@ def Build_CK(Map conf=[:]){ if (params.hipTensor_test && arch == "gfx90a" ){ // build and test hipTensor on gfx90a node sh """#!/bin/bash - rm -rf "${params.hipTensor_branch}".zip - rm -rf hipTensor-"${params.hipTensor_branch}" - wget https://github.com/ROCm/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip - unzip -o 
"${params.hipTensor_branch}".zip + rm -rf rocm-libraries + git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git + cd rocm-libraries + git sparse-checkout init --cone + git sparse-checkout set projects/hiptensor + git checkout "${params.hipTensor_branch}" """ - dir("hipTensor-${params.hipTensor_branch}"){ + dir("rocm-libraries/projects/hiptensor"){ sh """#!/bin/bash mkdir -p build ls -ltr From 1aa93ef551a31405aef5c8c14e869241ba96639d Mon Sep 17 00:00:00 2001 From: Zzz9990 Date: Wed, 10 Dec 2025 10:03:28 +0800 Subject: [PATCH 32/65] [CK_TILE MOE] add NT & preshuffle permute to cktile MOE (#3377) * update coherence --------- Co-authored-by: Zzz9990 --- .../core/arch/amd_buffer_addressing.hpp | 8 ++-- .../arch/amd_buffer_addressing_builtins.hpp | 8 ++-- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 46 +++++++++++++++---- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 12 +++++ ...ec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 15 ++++-- .../moe_flatmm_pipeline_agmem_bgmem_creg.hpp | 3 ++ ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 12 +++-- .../gemm/pipeline/gemm_pipeline_problem.hpp | 13 ++++-- 8 files changed, 88 insertions(+), 29 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index ba9201135c..8830adfdd9 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1413,10 +1413,10 @@ enum struct amd_buffer_coherence_enum WAVE_NT1 = 2, GROUP_NT0 = 1, GROUP_NT1 = 3, - DEVICE_NT0 = 8, - DEVICE_NT1 = 10, - SYSTEM_NT0 = 9, - SYSTEM_NT1 = 11, + DEVICE_NT0 = 16, + DEVICE_NT1 = 18, + SYSTEM_NT0 = 17, + SYSTEM_NT1 = 19, }; template ( - b_flat_ptr, - make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); + if constexpr(!FlatmmPipeline::BPreShufflePermute) + { + index_t kFlatK = + kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support 
splitK) + index_t kFlatN = kargs.N * kargs.K / kFlatK; + + return make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + } + else + { + index_t kFlatK = FlatmmPipeline::flatKPerWarp; + index_t kFlatN0 = (kargs.N >> 4); + index_t kFlatK0 = (kargs.K >> 7); + + auto b_tensor_view_naive = make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatK0, kFlatN0 - kargs.n_padded_zeros / NPerXdl, kFlatK), + make_tuple(kFlatK * (kFlatN0 - kargs.n_padded_zeros / NPerXdl), kFlatK, 1), + number{}, + number<1>{}); + return transform_tensor_view( + b_tensor_view_naive, + make_tuple( + make_pass_through_transform(kFlatN0 - kargs.n_padded_zeros / NPerXdl), + make_merge_transform_v3_division_mod(make_tuple(kFlatK0, kFlatK))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } }(); // TODO: enable vector write for C in ColMajor diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 79b36adec4..e4f186dead 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -24,6 +24,18 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1 return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; } + CK_TILE_HOST static constexpr amd_buffer_coherence_enum + GetBMemNTType(index_t M, index_t N, index_t K) + { + ck_tile::ignore = N; + ck_tile::ignore = K; + if(M <= 416) + { + return ck_tile::amd_buffer_coherence_enum::WAVE_NT1; + } + return ck_tile::amd_buffer_coherence_enum::coherence_default; + } + template CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop) { diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 8ec23b7570..74d82b8949 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -16,10 +16,12 @@ template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default, + bool BPreShufflePermute_ = false, + typename ComputeDataType_ = ADataType_> struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem { using BlockGemmShape = BlockGemmShape_; @@ -183,6 +187,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. 
static constexpr bool DoubleSmemBuffer = false; + static constexpr auto BMemNTType = Problem::BMemNTType; + static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute; + CK_TILE_HOST_DEVICE static constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) { diff --git a/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp b/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp index fe6d3ec830..5681726afe 100644 --- a/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp @@ -115,6 +115,9 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp; static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp; + static constexpr auto BMemNTType = Problem::BMemNTType; + static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 87ae7f57d8..69e9441ae5 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -16,10 +16,12 @@ template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default, + bool BPreShufflePermute_ = false, + typename ComputeDataType_ = ADataType_> struct MXFlatmmPipelineProblem : FlatmmPipelineProblem { using BlockGemmShape = BlockGemmShape_; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index e35f4ce70d..46c1f69b12 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -316,10 +316,12 @@ template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default, + bool BPreShufflePermute_ = false, + typename ComputeDataType_ = ADataType_> struct FlatmmPipelineProblem { using Traits = remove_cvref_t; @@ -353,6 +355,9 @@ struct FlatmmPipelineProblem static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; + static constexpr auto BMemNTType = BMemNTType_; + static constexpr bool BPreShufflePermute = BPreShufflePermute_; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off From fc22320d783a6b73798a23d8d20fb24e3a5e4040 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Wed, 10 Dec 2025 09:30:30 +0200 Subject: [PATCH 33/65] [CK_TILE] Split-K autodeduction (#3351) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First version of split-K autodeduction. * Fix circular dependency and kernel construction. * Fix tolerance calculation for bwd weight example. * Simplify kernel construction. * Fix kernel launching bug for split-K autodeduce. * Add split-K autodeduction support for the two stage example. * Fix a corner case. * Fix clang-format. * Fix clang-format for inc files. * Add missing header. * Prevent too large split-K values. * Fix formatting. * Add unit tests for IsSupportedArgument in grouped bwd conv. * clang-format. * Fix merge conflicts. * Address feedback from code review. * clang-format * Fix new tests after merge. 
--------- Co-authored-by: Ville Pietilä <> --- ...ed_convolution_backward_weight_invoker.hpp | 16 +- ...tion_backward_weight_two_stage_invoker.hpp | 18 +- .../grouped_convolution_utils.hpp | 6 + ...grouped_convolution_bwd_weight_example.inc | 51 ++-- include/ck_tile/host/device_prop.hpp | 18 ++ include/ck_tile/ops/grouped_convolution.hpp | 1 + ...ped_convolution_backward_weight_kernel.hpp | 88 ++++++- .../utils/split_k_utils.hpp | 81 ++++++ test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/grouped_conv/CMakeLists.txt | 7 + .../test_ck_tile_grouped_conv_bwd_weight.cpp | 249 ++++++++++++++++++ 11 files changed, 485 insertions(+), 51 deletions(-) create mode 100644 include/ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp create mode 100644 test/ck_tile/grouped_conv/CMakeLists.txt create mode 100644 test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index 0891e8c20b..afe43cd1c0 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightInvoker typename DsDataType = ck_tile::tuple<>, typename DsLayout = ck_tile::tuple<>, typename CDEElementWise = ck_tile::element_wise::PassThrough> - static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, - const ck_tile::stream_config& s) + static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, + const ck_tile::stream_config& s) { // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< @@ -105,9 +105,9 @@ struct GroupedConvolutionBackwardWeightInvoker TilePartitioner, GemmPipeline, ConvEpilogue>; - auto kargs = Kernel::MakeKernelArgs(args); + const auto kargs = 
Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args); + const dim3 grids = Kernel::GridSize(kargs); const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) @@ -130,7 +130,7 @@ struct GroupedConvolutionBackwardWeightInvoker } auto preprocess = [&]() { - if(args.k_batch > 1) + if(kargs.k_batch > 1) { ck_tile::hip_check_error( hipMemsetAsync(kargs.wei_ptr, @@ -140,10 +140,14 @@ struct GroupedConvolutionBackwardWeightInvoker } }; - return ck_tile::launch_kernel_time_mask( + const auto ave_time = ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + const auto split_k = kargs.k_batch; + + return InvokerResult{ave_time, split_k}; }; if(args.k_batch == 1) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index 50c0ce4f87..9221746560 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker typename DsDataType = ck_tile::tuple<>, typename DsLayout = ck_tile::tuple<>, typename CDEElementWise = ck_tile::element_wise::PassThrough> - static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, - const ck_tile::stream_config& s) + static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, + const ck_tile::stream_config& s) { using WorkspaceDataType = float; @@ -118,9 +118,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker sizeof(WorkspaceDataType)); ck_tile::GroupedConvBwdWeightHostArgs ws_args = ck_tile::GroupedConvBwdWeightHostArgs(args); - auto c_ptr = ws_args.wei_ptr; - ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); 
- auto kargs = Kernel::MakeKernelArgs(ws_args); + auto c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + const auto kargs = Kernel::MakeKernelArgs(ws_args); const dim3 grids = Kernel::GridSize(kargs); const dim3 blocks = Kernel::BlockSize(); @@ -184,7 +184,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker } auto preprocess = [&]() { - if(args.k_batch > 1) + if(kargs.k_batch > 1) ck_tile::hip_check_error( hipMemsetAsync(ws_args.wei_ptr, 0, @@ -192,7 +192,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker s.stream_id_)); }; - return ck_tile::launch_kernel_time_mask( + const auto ave_time = ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), @@ -206,6 +206,10 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ck_tile::make_tuple(shape[1], 1), // Output Stride input_tensors, static_cast(c_ptr))); + + const auto split_k = kargs.k_batch; + + return InvokerResult{ave_time, split_k}; }; if(args.k_batch == 1) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index b687e0a660..63dd54dcae 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -132,3 +132,9 @@ auto create_args(int argc, char* argv[]) bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } + +struct InvokerResult +{ + float ave_time; + ck_tile::index_t split_k; +}; diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc index 2496a1b0d2..b0a140993a 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc 
@@ -14,22 +14,22 @@ template -float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args, - int n_warmup, - int n_repeat) +InvokerResult invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args, + int n_warmup, + int n_repeat) { - float ave_time = Invoker::template grouped_conv_bwd_weight( + auto res = Invoker::template grouped_conv_bwd_weight( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - return ave_time; + return res; } template (args, n_warmup, n_repeat); + auto res = invoke_grouped_conv_bwd_weight(args, n_warmup, n_repeat); + const float ave_time = res.ave_time; weight_dev_buf.FromDevice(weight.data()); @@ -172,9 +173,11 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_); const float max_accumulated_value = *std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end()); + + const ck_tile::index_t split_k = res.split_k; const auto rtol_atol = calculate_rtol_atol( - GemmK, kbatch, max_accumulated_value); + GemmK, split_k, max_accumulated_value); pass = ck_tile::check_err(weight, weight_host_ref, "Error: Incorrect results!", diff --git a/include/ck_tile/host/device_prop.hpp b/include/ck_tile/host/device_prop.hpp index 2d7dc7dd18..e95ccfcfb4 100644 --- a/include/ck_tile/host/device_prop.hpp +++ b/include/ck_tile/host/device_prop.hpp @@ -70,6 +70,24 @@ inline bool is_load_tr_supported() // Check if load transpose is supported. 
return get_device_name() == "gfx950"; } + +inline size_t get_num_cus() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return 0; + } + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return 0; + } + return static_cast(props.multiProcessorCount); +} + } // namespace ck_tile #endif diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 1dd13b6246..23a72d79e9 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp" #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index f43bfdacac..c9e81d4744 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -14,6 +14,8 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp" + #ifdef CK_EXPERIMENTAL_BUILDER #include 
"ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp" #endif @@ -62,8 +64,6 @@ struct GroupedConvBwdWeightKernelArgs input_left_pads = {static_cast(args.input_left_pads_[0])}; input_right_pads = {static_cast(args.input_right_pads_[0])}; - k_batch = args.k_batch; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -104,11 +104,14 @@ struct GroupedConvBwdWeightKernelArgs GemmK = a_grid_desc_k_m.get_length(number<0>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + k_batch = args.k_batch; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch + << std::endl; } } @@ -147,8 +150,6 @@ struct GroupedConvBwdWeightKernelArgs input_right_pads = {static_cast(args.input_right_pads_[0]), static_cast(args.input_right_pads_[1])}; - k_batch = args.k_batch; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -189,11 +190,14 @@ struct GroupedConvBwdWeightKernelArgs GemmK = a_grid_desc_k_m.get_length(number<0>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + k_batch = args.k_batch; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch + << std::endl; } } @@ -239,8 +243,6 @@ struct GroupedConvBwdWeightKernelArgs static_cast(args.input_right_pads_[1]), static_cast(args.input_right_pads_[2])}; - k_batch = args.k_batch; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -281,11 +283,14 @@ struct 
GroupedConvBwdWeightKernelArgs GemmK = a_grid_desc_k_m.get_length(number<0>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + k_batch = args.k_batch; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch + << std::endl; } } @@ -398,7 +403,6 @@ struct GroupedConvolutionBackwardWeightKernel using GroupedConvBwdWeightKernelArgsSpecialized = GroupedConvBwdWeightKernelArgs; - // TODO: Enable this static constexpr bool IsSplitKSupported = true; static constexpr auto I0 = number<0>(); @@ -476,7 +480,24 @@ struct GroupedConvolutionBackwardWeightKernel std::cout << "NPerBlock: " << number{} << std::endl; std::cout << "KPerBlock: " << number{} << std::endl; } - return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs); + + auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs); + + using KernelImpl = GroupedConvolutionBackwardWeightKernel; + + // Negative k_batch value: split-K autodeduction. + if(kernel_args.k_batch < 0) + { + const auto optimal_split_k = + calculate_optimal_k_batch( + kernel_args); + kernel_args.k_batch = optimal_split_k; + } + + return kernel_args; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -514,15 +535,54 @@ struct GroupedConvolutionBackwardWeightKernel CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs) { + if(kargs.k_batch < 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "k_batch must be at least one. 
Ensure argument is created via MakeKernelArgs."); + } + return false; + } + + if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add) + { + if(kargs.k_batch == 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1."); + } + return false; + } + } + + if constexpr(!std::is_same_v && + !std::is_same_v) + { + // The epilogue performs atomic add related to split-K using the ODataType. + // If the type is less accurate than float, large split-K values may lead to + // accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases. + if(kargs.k_batch > 128) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "For epilogue output data type that is not float/double, we must have " + "k_batch <= 128."); + } + return false; + } + } + if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value) || - !IsSplitKSupported) + is_any_of::value)) { if(kargs.k_batch != 1) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { - CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + CK_TILE_ERROR("Conditions not met for K_batch > 1!"); } return false; } diff --git a/include/ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp new file mode 100644 index 0000000000..072134dbe7 --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp @@ -0,0 +1,81 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once +#include + +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/host/device_prop.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST index_t get_max_occupancy_for_kernel() +{ + constexpr int dynamic_smem_size = 0; + constexpr int min_blocks_per_cu = 1; + + const auto kernel_ptr = kentry; + + int max_occupancy = 0; + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, kernel_ptr, BlockSize, dynamic_smem_size)); + + return static_cast(max_occupancy); +} + +CK_TILE_HOST index_t get_best_occupancy_k_batch_value(index_t max_occupancy, index_t grid_size) +{ + static const index_t num_cus = get_num_cus(); + const index_t max_capacity = max_occupancy * num_cus; + + index_t k_batch = 1; + const auto optimal_split = static_cast(std::floor((1.0 * max_capacity) / grid_size)); + if(optimal_split > 1) + { + k_batch = optimal_split; + } + + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: " + << max_occupancy << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl; + } + return k_batch; +} + +template +struct ActiveWorkgroupsPerCU +{ + CK_TILE_HOST ActiveWorkgroupsPerCU() + { + max_occupancy_ = get_max_occupancy_for_kernel(); + } + index_t max_occupancy_{1}; +}; + +template +CK_TILE_HOST index_t calculate_optimal_k_batch(const KernelArgs& kargs) +{ + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + + const auto grid_size = TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN) * kargs.GemmBatch; + auto optimal_k_batch = + get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); + + const auto max_allowed_k_batch = kargs.GemmK; + optimal_k_batch = 
std::min(optimal_k_batch, max_allowed_k_batch); + + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << optimal_k_batch << std::endl; + } + + return optimal_k_batch; +} + +} // namespace ck_tile diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 6378bb8e43..197c9d6e1d 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -38,3 +38,4 @@ add_subdirectory(atomic_add_op) add_subdirectory(fmha) add_subdirectory(gemm_tile_engine) add_subdirectory(pooling) +add_subdirectory(grouped_conv) diff --git a/test/ck_tile/grouped_conv/CMakeLists.txt b/test/ck_tile/grouped_conv/CMakeLists.txt new file mode 100644 index 0000000000..5bc10ffddd --- /dev/null +++ b/test/ck_tile/grouped_conv/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") + add_gtest_executable(test_ck_tile_grouped_conv_bwd_weight test_ck_tile_grouped_conv_bwd_weight.cpp) +endif() diff --git a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp new file mode 100644 index 0000000000..f37065f7c7 --- /dev/null +++ b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp @@ -0,0 +1,249 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "gtest/gtest.h" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" +#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp" + +using namespace ck_tile; + +struct TestConvConfig +{ + static constexpr index_t VectorSizeA = 4; + static constexpr index_t VectorSizeB = 8; + static constexpr index_t VectorSizeC = 8; + + static constexpr index_t M_Tile = 128; + static constexpr index_t N_Tile = 128; + static constexpr index_t K_Tile = 32; + + static constexpr index_t M_Warp = 2; + static constexpr index_t N_Warp = 2; + static constexpr index_t K_Warp = 1; + + static constexpr index_t M_Warp_Tile = 16; + static constexpr index_t N_Warp_Tile = 16; + static constexpr index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr GemmPipeline Pipeline = GemmPipeline::COMPUTE_V3; + static constexpr index_t NumWaveGroups = 1; + static constexpr index_t NumGroupsToMerge = 1; + static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave; +}; + +// Helper to build full kernel type +template +struct BuildKernel +{ + using GemmShape = TileGemmShape< + sequence, + sequence, + sequence>; + + using ConvTraits = GroupedConvTraits, + OutLayout, + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC, + ConvConfig::NumGroupsToMerge>; + + using TilePartitioner = GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = + TileGemmUniversalTraits; + + using GemmPipelineProblem = + GemmPipelineProblem, + element_wise::PassThrough, + element_wise::PassThrough, + PrecType, // WeiDataType (C in bwd weight) + ConvTraits::FixedGemmParams::FixedVectorSize, + ConvTraits::VectorSizeA, + ConvTraits::VectorSizeB>; + + using UniversalGemmProblem = + UniversalGemmPipelineProblem; + + using GemmPipeline = 
GemmPipelineAgBgCrCompV3; + + using EpilogueProblem = CShuffleEpilogueProblem, + float, + PrecType, + typename ConvTraits::ImplicitGemmDsLayout, + typename ConvTraits::FixedGemmParams::ELayout, + element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + ConvConfig::M_Warp, + ConvConfig::N_Warp, + ConvConfig::M_Warp_Tile, + ConvConfig::N_Warp_Tile, + ConvConfig::K_Warp_Tile, + ConvTraits::FixedGemmParams::TransposeC, + MemOp, + ConvConfig::NumWaveGroups, + ConvTraits::FixedGemmParams::FixedVectorSize, + ConvTraits::VectorSizeC>; + + using Epilogue = CShuffleEpilogue; + + using type = + GroupedConvolutionBackwardWeightKernel; +}; + +// Helper to create 2D host args +static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t G, + index_t N, + index_t K, + index_t C, + index_t Y, + index_t X, + index_t Hi, + index_t Wi, + index_t stride_y, + index_t stride_x, + index_t dilation_y, + index_t dilation_x, + index_t left_pad_y, + index_t left_pad_x, + index_t right_pad_y, + index_t right_pad_x, + index_t k_batch = 1) +{ + auto conv_param = conv::ConvParam{2, + G, + N, + K, + C, + {Y, X}, + {Hi, Wi}, + {stride_y, stride_x}, + {dilation_y, dilation_x}, + {left_pad_y, left_pad_x}, + {right_pad_y, right_pad_x}}; + + return GroupedConvBwdWeightHostArgs{conv_param, nullptr, nullptr, {}, nullptr, k_batch}; +} + +static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch) +{ + return create_2d_host_args(2, 2, 8, 8, 3, 3, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, k_batch); +} + +class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test +{ +}; + +TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, ValidKBatch) +{ + using Kernel = typename BuildKernel::type; + + auto host_args_kbatch_1 = create_2d_host_args(1); + auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_1)); + + auto host_args_kbatch_4 = create_2d_host_args(4); + auto kargs_4 = 
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_4); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_4)); +} + +TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, InvalidKBatchLessThanOne) +{ + using Kernel = typename BuildKernel::type; + + auto host_args_kbatch_0 = create_2d_host_args(0); + auto kargs = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_0); + EXPECT_FALSE(Kernel::IsSupportedArgument(kargs)); +} + +TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreaterThanOne) +{ + using Kernel = typename BuildKernel::type; + + // k_batch = 1 should fail with atomic_add + auto host_args_kbatch_1 = create_2d_host_args(1); + auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1); + EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_1)); + + // k_batch = 2 should pass + auto host_args_kbatch_2 = create_2d_host_args(2); + auto kargs_2 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_2); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2)); +} + +TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch) +{ + using Kernel = typename BuildKernel::type; + + // k_batch = 128 should pass + auto host_args_kbatch_128 = create_2d_host_args(128); + auto kargs_128 = + typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_128); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_128)); + + // k_batch = 129 should fail for half_t output + auto host_args_kbatch_129 = create_2d_host_args(129); + auto kargs_129 = + typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_129); + EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_129)); +} From b15df372553e0f80a660124f1b558d9cb276bd08 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 10 Dec 2025 23:08:41 +0800 Subject: [PATCH 34/65] fix: python 3.8 compatibility in fmha codegen (#3388) --- 
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c00bdcea3b..edc0e049c5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -770,7 +770,7 @@ def create_kernel( class CompatibilityRuleFactory: @staticmethod - def get_rules() -> list[CompatibilityRule]: + def get_rules() -> List[CompatibilityRule]: # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: if problem_ctx.mode == "group": @@ -812,7 +812,7 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"}) @classmethod - def get_rules(cls) -> list[CompatibilityRule]: + def get_rules(cls) -> List[CompatibilityRule]: rules = CompatibilityRuleFactory.get_rules() def check_hdim_tile( @@ -846,7 +846,7 @@ class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9): ) @classmethod - def get_rules(cls) -> list[CompatibilityRule]: + def get_rules(cls) -> List[CompatibilityRule]: rules = CompatibilityRuleFactoryGfx9.get_rules() def check_tile_pipeline( From 15ed65db35e6702593cd8ed1d603222fb11684e4 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Wed, 10 Dec 2025 12:25:23 -0800 Subject: [PATCH 35/65] Improve sequence sorting and add unit tests (#3376) Old sequence sort code was showing up on build profiles. Convert it to constexpr functions for much more efficient build-time execution. The sorting is still O(N^2), but our sequences are small enough that it executes quickly. This reduced compilation time of a small convolution by more than 10% and the overall time spent in the compiler on a narrow build by 6%.
--- include/ck/utility/sequence.hpp | 325 ++++++--------- test/CMakeLists.txt | 1 + test/util/CMakeLists.txt | 7 + test/util/unit_sequence.cpp | 684 ++++++++++++++++++++++++++++++++ 4 files changed, 808 insertions(+), 209 deletions(-) create mode 100644 test/util/CMakeLists.txt create mode 100644 test/util/unit_sequence.cpp diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 9f97d44a4a..6e68690048 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -380,236 +380,143 @@ struct sequence_reduce }; #endif -template -struct sequence_sort_impl +// Implement sequence_sort and sequence_unique_sort using constexpr functions (C++17) +namespace sort_impl { + +// Temporary arrays to hold values during operations with capacity N and mutable size. +template +struct IndexedValueArray { - template - struct sorted_sequence_merge_impl - { - static constexpr bool choose_left = LeftValues::Front() < RightValues::Front(); - - static constexpr index_t chosen_value = - choose_left ? LeftValues::Front() : RightValues::Front(); - static constexpr index_t chosen_id = choose_left ? 
LeftIds::Front() : RightIds::Front(); - - using new_merged_values = decltype(MergedValues::PushBack(Number{})); - using new_merged_ids = decltype(MergedIds::PushBack(Number{})); - - using new_left_values = - typename conditional::type; - using new_left_ids = - typename conditional::type; - - using new_right_values = - typename conditional::type; - using new_right_ids = - typename conditional::type; - - using merge = sorted_sequence_merge_impl; - // this is output - using merged_values = typename merge::merged_values; - using merged_ids = typename merge::merged_ids; - }; - - template - struct sorted_sequence_merge_impl, - Sequence<>, - MergedValues, - MergedIds, - Comp> - { - using merged_values = typename sequence_merge::type; - using merged_ids = typename sequence_merge::type; - }; - - template - struct sorted_sequence_merge_impl, - Sequence<>, - RightValues, - RightIds, - MergedValues, - MergedIds, - Comp> - { - using merged_values = typename sequence_merge::type; - using merged_ids = typename sequence_merge::type; - }; - - template - struct sorted_sequence_merge - { - using merge = sorted_sequence_merge_impl, - Sequence<>, - Comp>; - - using merged_values = typename merge::merged_values; - using merged_ids = typename merge::merged_ids; - }; - - static constexpr index_t nsize = Values::Size(); - - using split_unsorted_values = sequence_split; - using split_unsorted_ids = sequence_split; - - using left_unsorted_values = typename split_unsorted_values::left_type; - using left_unsorted_ids = typename split_unsorted_ids::left_type; - using left_sort = sequence_sort_impl; - using left_sorted_values = typename left_sort::sorted_values; - using left_sorted_ids = typename left_sort::sorted_ids; - - using right_unsorted_values = typename split_unsorted_values::right_type; - using right_unsorted_ids = typename split_unsorted_ids::right_type; - using right_sort = sequence_sort_impl; - using right_sorted_values = typename right_sort::sorted_values; - using right_sorted_ids = 
typename right_sort::sorted_ids; - - using merged_sorted = sorted_sequence_merge; - - using sorted_values = typename merged_sorted::merged_values; - using sorted_ids = typename merged_sorted::merged_ids; + index_t values[N > 0 ? N : 1]; + index_t ids[N > 0 ? N : 1]; + index_t size = 0; }; -template -struct sequence_sort_impl, Sequence, Compare> +template +constexpr auto make_indexed_value_array(Sequence) { - static constexpr bool choose_x = Compare{}(ValueX, ValueY); + constexpr index_t N = sizeof...(Is); + IndexedValueArray result = {{Is...}, {}, N}; + for(index_t i = 0; i < N; ++i) + { + result.ids[i] = i; + } + return result; +} - using sorted_values = - typename conditional, Sequence>::type; - using sorted_ids = typename conditional, Sequence>::type; +enum class SortField +{ + Values, + Ids }; -template -struct sequence_sort_impl, Sequence, Compare> +// Perform an insertion sort on an IndexedValueArray. +template +constexpr auto insertion_sort(IndexedValueArray arr, Compare comp) { - using sorted_values = Sequence; - using sorted_ids = Sequence; + for(index_t i = 1; i < arr.size; ++i) + { + index_t key_val = arr.values[i]; + index_t key_id = arr.ids[i]; + index_t j = i - 1; + while(j >= 0 && comp(key_val, arr.values[j])) + { + arr.values[j + 1] = arr.values[j]; + arr.ids[j + 1] = arr.ids[j]; + --j; + } + arr.values[j + 1] = key_val; + arr.ids[j + 1] = key_id; + } + return arr; +} + +// Remove duplicates from a sorted IndexedValueArray. 
+template +constexpr auto unique(const IndexedValueArray& sorted, Equal eq) +{ + IndexedValueArray result{}; + if constexpr(N == 0) + { + return result; + } + result.size = 1; + result.values[0] = sorted.values[0]; + result.ids[0] = sorted.ids[0]; + for(index_t i = 1; i < sorted.size; ++i) + { + if(!eq(sorted.values[i], sorted.values[i - 1])) + { + result.values[result.size] = sorted.values[i]; + result.ids[result.size] = sorted.ids[i]; + ++result.size; + } + } + return result; +} + +// Compute sorted (and optionally unique) IndexedValueArray from input Sequence. +template +constexpr auto compute_sorted(Sequence seq, Compare comp, Equal eq) +{ + auto sorted = insertion_sort(make_indexed_value_array(seq), comp); + return Unique ? unique(sorted, eq) : sorted; +} + +// Cache the sorted results to avoid recomputation. +template +struct SortedCache +{ + static constexpr auto data = compute_sorted(Seq{}, Compare{}, Equal{}); }; -template -struct sequence_sort_impl, Sequence<>, Compare> +// Build sorted value and ID sequences from cached sorted data +template +constexpr index_t get_sorted_field() { - using sorted_values = Sequence<>; - using sorted_ids = Sequence<>; + constexpr auto& data = SortedCache::data; + return (Field == SortField::Values) ? 
data.values[I] : data.ids[I]; +} + +template +struct SortedSequences; + +template +struct SortedSequences> +{ + using values_type = + Sequence()...>; + using ids_type = + Sequence()...>; }; +template +using sorted_sequences_t = SortedSequences< + Unique, + Seq, + Compare, + Equal, + typename arithmetic_sequence_gen<0, SortedCache::data.size, 1>:: + type>; + +using Equal = ck::math::equal; + +} // namespace sort_impl + template struct sequence_sort { - using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type; - using sort = sequence_sort_impl; - - // this is output - using type = typename sort::sorted_values; - using sorted2unsorted_map = typename sort::sorted_ids; + using sorted_seqs = sort_impl::sorted_sequences_t; + using type = typename sorted_seqs::values_type; + using sorted2unsorted_map = typename sorted_seqs::ids_type; }; template struct sequence_unique_sort { - template - struct sorted_sequence_uniquify_impl - { - static constexpr index_t current_value = RemainValues::Front(); - static constexpr index_t current_id = RemainIds::Front(); - - static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back()); - - using new_remain_values = decltype(RemainValues::PopFront()); - using new_remain_ids = decltype(RemainIds::PopFront()); - - using new_uniquified_values = - typename conditional{})), - UniquifiedValues>::type; - - using new_uniquified_ids = - typename conditional{})), - UniquifiedIds>::type; - - using uniquify = sorted_sequence_uniquify_impl; - - // this is output - using uniquified_values = typename uniquify::uniquified_values; - using uniquified_ids = typename uniquify::uniquified_ids; - }; - - template - struct sorted_sequence_uniquify_impl, - Sequence<>, - UniquifiedValues, - UniquifiedIds, - Eq> - { - using uniquified_values = UniquifiedValues; - using uniquified_ids = UniquifiedIds; - }; - - template - struct sorted_sequence_uniquify - { - using uniquify = sorted_sequence_uniquify_impl, - Sequence, - Eq>; 
- - using uniquified_values = typename uniquify::uniquified_values; - using uniquified_ids = typename uniquify::uniquified_ids; - }; - - using sort = sequence_sort; - using sorted_values = typename sort::type; - using sorted_ids = typename sort::sorted2unsorted_map; - - using uniquify = sorted_sequence_uniquify; - - // this is output - using type = typename uniquify::uniquified_values; - using sorted2unsorted_map = typename uniquify::uniquified_ids; + using sorted_seqs = sort_impl::sorted_sequences_t; + using type = typename sorted_seqs::values_type; + using sorted2unsorted_map = typename sorted_seqs::ids_type; }; template diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c221f11f46..b7db14945d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -310,3 +310,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx12") endif() add_subdirectory(position_embedding) add_subdirectory(scatter_gather) +add_subdirectory(util) diff --git a/test/util/CMakeLists.txt b/test/util/CMakeLists.txt new file mode 100644 index 0000000000..bf0a444f18 --- /dev/null +++ b/test/util/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_gtest_executable(unit_sequence unit_sequence.cpp) +if(result EQUAL 0) + target_link_libraries(unit_sequence PRIVATE utility) +endif() diff --git a/test/util/unit_sequence.cpp b/test/util/unit_sequence.cpp new file mode 100644 index 0000000000..f09fd86e06 --- /dev/null +++ b/test/util/unit_sequence.cpp @@ -0,0 +1,684 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include "ck/utility/sequence.hpp" +#include "ck/utility/functional.hpp" + +using namespace ck; + +// Test basic Sequence construction and properties +TEST(Sequence, BasicConstruction) +{ + using Seq = Sequence<1, 2, 3, 4, 5>; + EXPECT_EQ(Seq::Size(), 5); + EXPECT_EQ(Seq::mSize, 5); +} + +TEST(Sequence, EmptySequence) +{ + using Seq = Sequence<>; + EXPECT_EQ(Seq::Size(), 0); + EXPECT_EQ(Seq::mSize, 0); +} + +// Test At() method +TEST(Sequence, AtRuntime) +{ + using Seq = Sequence<10, 20, 30, 40>; + EXPECT_EQ(Seq::At(0), 10); + EXPECT_EQ(Seq::At(1), 20); + EXPECT_EQ(Seq::At(2), 30); + EXPECT_EQ(Seq::At(3), 40); +} + +TEST(Sequence, AtCompileTime) +{ + using Seq = Sequence<10, 20, 30, 40>; + EXPECT_EQ(Seq::At(Number<0>{}), 10); + EXPECT_EQ(Seq::At(Number<1>{}), 20); + EXPECT_EQ(Seq::At(Number<2>{}), 30); + EXPECT_EQ(Seq::At(Number<3>{}), 40); +} + +TEST(Sequence, OperatorBracket) +{ + constexpr auto seq = Sequence<5, 10, 15>{}; + EXPECT_EQ(seq[Number<0>{}], 5); + EXPECT_EQ(seq[Number<1>{}], 10); + EXPECT_EQ(seq[Number<2>{}], 15); +} + +// Test Front() and Back() +TEST(Sequence, FrontBack) +{ + using Seq = Sequence<100, 200, 300>; + EXPECT_EQ(Seq::Front(), 100); + EXPECT_EQ(Seq::Back(), 300); +} + +TEST(Sequence, FrontBackSingleElement) +{ + using Seq = Sequence<42>; + EXPECT_EQ(Seq::Front(), 42); + EXPECT_EQ(Seq::Back(), 42); +} + +// Test PushFront and PushBack +TEST(Sequence, PushFront) +{ + using Seq = Sequence<2, 3, 4>; + using Result = decltype(Seq::PushFront(Sequence<1>{})); + using Expected = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PushFrontNumbers) +{ + using Seq = Sequence<3, 4>; + using Result = decltype(Seq::PushFront(Number<1>{}, Number<2>{})); + using Expected = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PushBack) +{ + using Seq = Sequence<1, 2, 3>; + using Result = decltype(Seq::PushBack(Sequence<4, 5>{})); + using Expected = Sequence<1, 2, 
3, 4, 5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PushBackNumbers) +{ + using Seq = Sequence<1, 2>; + using Result = decltype(Seq::PushBack(Number<3>{}, Number<4>{})); + using Expected = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +// Test PopFront and PopBack +TEST(Sequence, PopFront) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(Seq::PopFront()); + using Expected = Sequence<2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PopBack) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(Seq::PopBack()); + using Expected = Sequence<1, 2, 3>; + EXPECT_TRUE((is_same::value)); +} + +// Test Extract +TEST(Sequence, ExtractByNumbers) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Result = decltype(Seq::Extract(Number<0>{}, Number<2>{}, Number<4>{})); + using Expected = Sequence<10, 30, 50>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, ExtractBySequence) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Result = decltype(Seq::Extract(Sequence<1, 3>{})); + using Expected = Sequence<20, 40>; + EXPECT_TRUE((is_same::value)); +} + +// Test Modify +TEST(Sequence, Modify) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(Seq::Modify(Number<2>{}, Number<99>{})); + using Expected = Sequence<1, 2, 99, 4>; + EXPECT_TRUE((is_same::value)); +} + +// Test Transform +TEST(Sequence, Transform) +{ + using Seq = Sequence<1, 2, 3, 4>; + auto double_it = [](auto x) { return 2 * x; }; + using Result = decltype(Seq::Transform(double_it)); + using Expected = Sequence<2, 4, 6, 8>; + EXPECT_TRUE((is_same::value)); +} + +// Test Reverse +TEST(Sequence, Reverse) +{ + using Seq = Sequence<1, 2, 3, 4, 5>; + using Result = decltype(Seq::Reverse()); + using Expected = Sequence<5, 4, 3, 2, 1>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, ReverseSingleElement) +{ + using Seq = Sequence<42>; + using Result = decltype(Seq::Reverse()); + using Expected = Sequence<42>; + 
EXPECT_TRUE((is_same::value)); +} + +// Test ReorderGivenNew2Old +TEST(Sequence, ReorderGivenNew2Old) +{ + using Seq = Sequence<10, 20, 30, 40>; + using Result = decltype(Seq::ReorderGivenNew2Old(Sequence<3, 1, 2, 0>{})); + using Expected = Sequence<40, 20, 30, 10>; + EXPECT_TRUE((is_same::value)); +} + +// Test ReorderGivenOld2New +TEST(Sequence, ReorderGivenOld2New) +{ + using Seq = Sequence<10, 20, 30, 40>; + using Result = decltype(Seq::ReorderGivenOld2New(Sequence<3, 1, 2, 0>{})); + using Expected = Sequence<40, 20, 30, 10>; + EXPECT_TRUE((is_same::value)); +} + +// Test arithmetic_sequence_gen +TEST(SequenceGen, ArithmeticSequence) +{ + using Result = typename arithmetic_sequence_gen<0, 5, 1>::type; + using Expected = Sequence<0, 1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, ArithmeticSequenceWithIncrement) +{ + using Result = typename arithmetic_sequence_gen<0, 10, 2>::type; + using Expected = Sequence<0, 2, 4, 6, 8>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, ArithmeticSequenceNegativeIncrement) +{ + using Result = typename arithmetic_sequence_gen<10, 5, -1>::type; + using Expected = Sequence<10, 9, 8, 7, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, ArithmeticSequenceEmpty) +{ + using Result = typename arithmetic_sequence_gen<5, 5, 1>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +// Test uniform_sequence_gen +TEST(SequenceGen, UniformSequence) +{ + using Result = typename uniform_sequence_gen<5, 42>::type; + using Expected = Sequence<42, 42, 42, 42, 42>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, UniformSequenceZeroSize) +{ + using Result = typename uniform_sequence_gen<0, 42>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +// Test make_index_sequence +TEST(SequenceGen, MakeIndexSequence) +{ + using Result = make_index_sequence<5>; + using Expected = Sequence<0, 1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, 
MakeIndexSequenceZero) +{ + using Result = make_index_sequence<0>; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_merge +TEST(SequenceMerge, MergeTwoSequences) +{ + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeMultipleSequences) +{ + using Seq1 = Sequence<1, 2>; + using Seq2 = Sequence<3, 4>; + using Seq3 = Sequence<5, 6>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeSingleSequence) +{ + using Seq = Sequence<1, 2, 3>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_split +TEST(SequenceSplit, SplitInMiddle) +{ + using Seq = Sequence<1, 2, 3, 4, 5, 6>; + using Split = sequence_split; + using ExpectedLeft = Sequence<1, 2, 3>; + using ExpectedRight = Sequence<4, 5, 6>; + EXPECT_TRUE((is_same::value)); + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSplit, SplitAtBeginning) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Split = sequence_split; + using ExpectedLeft = Sequence<>; + using ExpectedRight = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSplit, SplitAtEnd) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Split = sequence_split; + using ExpectedLeft = Sequence<1, 2, 3, 4>; + using ExpectedRight = Sequence<>; + EXPECT_TRUE((is_same::value)); + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_sort +TEST(SequenceSort, SortAscending) +{ + using Seq = Sequence<5, 2, 8, 1, 9>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<1, 2, 5, 8, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortDescending) +{ + // Create a greater-than comparator + 
struct greater + { + __host__ __device__ constexpr bool operator()(index_t x, index_t y) const { return x > y; } + }; + using Seq = Sequence<5, 2, 8, 1, 9>; + using Result = typename sequence_sort::type; + using Expected = Sequence<9, 8, 5, 2, 1>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortAlreadySorted) +{ + using Seq = Sequence<1, 2, 3, 4, 5>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<1, 2, 3, 4, 5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortWithDuplicates) +{ + using Seq = Sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<1, 1, 2, 3, 4, 5, 5, 6, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortEmptySequence) +{ + using Seq = Sequence<>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortSingleElement) +{ + using Seq = Sequence<42>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<42>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_unique_sort +TEST(SequenceUniqueSort, UniqueSort) +{ + using Seq = Sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>; + using Result = + typename sequence_unique_sort, math::equal>::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceUniqueSort, UniqueSortNoDuplicates) +{ + using Seq = Sequence<5, 2, 8, 1, 9>; + using Result = + typename sequence_unique_sort, math::equal>::type; + using Expected = Sequence<1, 2, 5, 8, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceUniqueSort, UniqueSortAllSame) +{ + using Seq = Sequence<5, 5, 5, 5>; + using Result = + typename sequence_unique_sort, math::equal>::type; + using Expected = Sequence<5>; + EXPECT_TRUE((is_same::value)); +} + +// Test is_valid_sequence_map +TEST(SequenceMap, ValidMap) +{ + using Map = Sequence<0, 1, 2, 3>; + 
EXPECT_TRUE((is_valid_sequence_map::value)); +} + +TEST(SequenceMap, ValidMapPermuted) +{ + using Map = Sequence<2, 0, 3, 1>; + EXPECT_TRUE((is_valid_sequence_map::value)); +} + +TEST(SequenceMap, InvalidMapDuplicate) +{ + using Map = Sequence<0, 1, 1, 3>; + EXPECT_FALSE((is_valid_sequence_map::value)); +} + +TEST(SequenceMap, InvalidMapMissing) +{ + using Map = Sequence<0, 1, 3, 4>; + EXPECT_FALSE((is_valid_sequence_map::value)); +} + +// Test sequence_map_inverse +// Note: sequence_map_inverse inverts a mapping where Map[i] = j means old position i maps to new +// position j The inverse gives us new position i came from old position inverse[i] +TEST(SequenceMapInverse, InverseMap) +{ + // Map = <2, 0, 3, 1> means: old[0]->new[2], old[1]->new[0], old[2]->new[3], old[3]->new[1] + // Inverse should be: new[0]<-old[1], new[1]<-old[3], new[2]<-old[0], new[3]<-old[2] + using Map = Sequence<2, 0, 3, 1>; + using Result = typename sequence_map_inverse::type; + // Verify by checking that Map[Result[i]] == i for all i + EXPECT_EQ((Map::At(Number{})>{}) == 0), true); + EXPECT_EQ((Map::At(Number{})>{}) == 1), true); + EXPECT_EQ((Map::At(Number{})>{}) == 2), true); + EXPECT_EQ((Map::At(Number{})>{}) == 3), true); +} + +TEST(SequenceMapInverse, InverseIdentityMap) +{ + using Map = Sequence<0, 1, 2, 3>; + using Result = typename sequence_map_inverse::type; + // Verify by checking that Map[Result[i]] == i for all i (same as the other test) + EXPECT_EQ((Map::At(Number{})>{}) == 0), true); + EXPECT_EQ((Map::At(Number{})>{}) == 1), true); + EXPECT_EQ((Map::At(Number{})>{}) == 2), true); + EXPECT_EQ((Map::At(Number{})>{}) == 3), true); +} + +// Test sequence operators +TEST(SequenceOperators, Equality) +{ + constexpr auto seq1 = Sequence<1, 2, 3>{}; + constexpr auto seq2 = Sequence<1, 2, 3>{}; + constexpr auto seq3 = Sequence<1, 2, 4>{}; + EXPECT_TRUE(seq1 == seq2); + EXPECT_FALSE(seq1 == seq3); +} + +TEST(SequenceOperators, Addition) +{ + using Seq1 = Sequence<1, 2, 3>; + using Seq2 
= Sequence<4, 5, 6>; + using Result = decltype(Seq1{} + Seq2{}); + using Expected = Sequence<5, 7, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Subtraction) +{ + using Seq1 = Sequence<10, 20, 30>; + using Seq2 = Sequence<1, 2, 3>; + using Result = decltype(Seq1{} - Seq2{}); + using Expected = Sequence<9, 18, 27>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Multiplication) +{ + using Seq1 = Sequence<2, 3, 4>; + using Seq2 = Sequence<5, 6, 7>; + using Result = decltype(Seq1{} * Seq2{}); + using Expected = Sequence<10, 18, 28>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Division) +{ + using Seq1 = Sequence<10, 20, 30>; + using Seq2 = Sequence<2, 4, 5>; + using Result = decltype(Seq1{} / Seq2{}); + using Expected = Sequence<5, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Modulo) +{ + using Seq1 = Sequence<10, 20, 30>; + using Seq2 = Sequence<3, 7, 8>; + using Result = decltype(Seq1{} % Seq2{}); + using Expected = Sequence<1, 6, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, AdditionWithNumber) +{ + using Seq = Sequence<1, 2, 3>; + using Result = decltype(Seq{} + Number<10>{}); + using Expected = Sequence<11, 12, 13>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, SubtractionWithNumber) +{ + using Seq = Sequence<10, 20, 30>; + using Result = decltype(Seq{} - Number<5>{}); + using Expected = Sequence<5, 15, 25>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, MultiplicationWithNumber) +{ + using Seq = Sequence<2, 3, 4>; + using Result = decltype(Seq{} * Number<3>{}); + using Expected = Sequence<6, 9, 12>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, DivisionWithNumber) +{ + using Seq = Sequence<10, 20, 30>; + using Result = decltype(Seq{} / Number<5>{}); + using Expected = Sequence<2, 4, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, NumberAddition) +{ + using Seq = Sequence<1, 2, 3>; + using Result 
= decltype(Number<10>{} + Seq{}); + using Expected = Sequence<11, 12, 13>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, NumberMultiplication) +{ + using Seq = Sequence<2, 3, 4>; + using Result = decltype(Number<3>{} * Seq{}); + using Expected = Sequence<6, 9, 12>; + EXPECT_TRUE((is_same::value)); +} + +// Test helper functions +TEST(SequenceHelpers, MergeSequences) +{ + using Seq1 = Sequence<1, 2>; + using Seq2 = Sequence<3, 4>; + using Seq3 = Sequence<5, 6>; + using Result = decltype(merge_sequences(Seq1{}, Seq2{}, Seq3{})); + using Expected = Sequence<1, 2, 3, 4, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, TransformSequencesSingle) +{ + auto double_it = [](auto x) { return 2 * x; }; + using Seq = Sequence<1, 2, 3>; + using Result = decltype(transform_sequences(double_it, Seq{})); + using Expected = Sequence<2, 4, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, TransformSequencesTwo) +{ + auto add = [](auto x, auto y) { return x + y; }; + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Result = decltype(transform_sequences(add, Seq1{}, Seq2{})); + using Expected = Sequence<5, 7, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, TransformSequencesThree) +{ + auto add3 = [](auto x, auto y, auto z) { return x + y + z; }; + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Seq3 = Sequence<7, 8, 9>; + using Result = decltype(transform_sequences(add3, Seq1{}, Seq2{}, Seq3{})); + using Expected = Sequence<12, 15, 18>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, ReduceOnSequence) +{ + auto add = [](auto x, auto y) { return x + y; }; + constexpr auto seq = Sequence<1, 2, 3, 4, 5>{}; + constexpr auto result = reduce_on_sequence(seq, add, Number<0>{}); + EXPECT_EQ(result, 15); +} + +TEST(SequenceHelpers, SequenceAnyOf) +{ + auto is_even = [](auto x) { return x % 2 == 0; }; + constexpr auto seq1 = Sequence<1, 3, 5, 7>{}; + constexpr auto 
seq2 = Sequence<1, 3, 4, 7>{}; + EXPECT_FALSE(sequence_any_of(seq1, is_even)); + EXPECT_TRUE(sequence_any_of(seq2, is_even)); +} + +TEST(SequenceHelpers, SequenceAllOf) +{ + auto is_positive = [](auto x) { return x > 0; }; + constexpr auto seq1 = Sequence<1, 2, 3, 4>{}; + constexpr auto seq2 = Sequence<1, -2, 3, 4>{}; + EXPECT_TRUE(sequence_all_of(seq1, is_positive)); + EXPECT_FALSE(sequence_all_of(seq2, is_positive)); +} + +// Test scan operations +TEST(SequenceScan, ReverseInclusiveScan) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = + decltype(reverse_inclusive_scan_sequence(Seq{}, math::plus{}, Number<0>{})); + using Expected = Sequence<10, 9, 7, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceScan, ReverseExclusiveScan) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = + decltype(reverse_exclusive_scan_sequence(Seq{}, math::plus{}, Number<0>{})); + using Expected = Sequence<9, 7, 4, 0>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceScan, InclusiveScan) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(inclusive_scan_sequence(Seq{}, math::plus{}, Number<0>{})); + using Expected = Sequence<1, 3, 6, 10>; + EXPECT_TRUE((is_same::value)); +} + +// Test pick and modify operations +TEST(SequencePickModify, PickElementsByIds) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Ids = Sequence<0, 2, 4>; + using Result = decltype(pick_sequence_elements_by_ids(Seq{}, Ids{})); + using Expected = Sequence<10, 30, 50>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequencePickModify, PickElementsByMask) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Mask = Sequence<1, 0, 1, 0, 1>; + using Result = decltype(pick_sequence_elements_by_mask(Seq{}, Mask{})); + using Expected = Sequence<10, 30, 50>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequencePickModify, ModifyElementsByIds) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Values = Sequence<99, 88>; + using Ids = Sequence<1, 3>; + using Result = 
decltype(modify_sequence_elements_by_ids(Seq{}, Values{}, Ids{})); + using Expected = Sequence<10, 99, 30, 88, 50>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_reduce +TEST(SequenceReduce, ReduceTwoSequences) +{ + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Result = typename sequence_reduce, Seq1, Seq2>::type; + using Expected = Sequence<5, 7, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceReduce, ReduceMultipleSequences) +{ + using Seq1 = Sequence<1, 2>; + using Seq2 = Sequence<3, 4>; + using Seq3 = Sequence<5, 6>; + using Result = typename sequence_reduce, Seq1, Seq2, Seq3>::type; + using Expected = Sequence<9, 12>; + EXPECT_TRUE((is_same::value)); +} From 8270900d606398868e747b7f9097484ee73a4cb4 Mon Sep 17 00:00:00 2001 From: Geo Min Date: Wed, 10 Dec 2025 17:34:41 -0800 Subject: [PATCH 36/65] [ci] Bumping TheRock commit hash (#3385) * Bumping TheRock commit hash * new docker hash * Using new runner name --- .github/workflows/therock-ci-linux.yml | 4 ++-- .github/workflows/therock-ci.yml | 2 +- .github/workflows/therock-test-component.yml | 2 +- .github/workflows/therock-test-packages.yml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 86d134e456..0baa503334 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -20,7 +20,7 @@ jobs: permissions: id-token: write container: - image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:2f3ebd0beb04c449fdb36933e54bdc69483b914fb9005594d3fc9444c206b54b + image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:583d473f263a289222c48d4b493e2956b2354a45796f09dee6f2c8ecd4504ab6 options: -v /runner/config:/home/awsconfig/ env: AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} @@ -54,7 +54,7 @@ jobs: with: repository: "ROCm/TheRock" path: "TheRock" - ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit + ref: 
d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit - name: Setup ccache run: | diff --git a/.github/workflows/therock-ci.yml b/.github/workflows/therock-ci.yml index 40a3b0bec8..0951244f31 100644 --- a/.github/workflows/therock-ci.yml +++ b/.github/workflows/therock-ci.yml @@ -65,7 +65,7 @@ jobs: -DTHEROCK_USE_EXTERNAL_ROCM_LIBRARIES=ON -DTHEROCK_ROCM_LIBRARIES_SOURCE_DIR=../ amdgpu_families: "gfx94X-dcgpu" - test_runs_on: "linux-mi325-1gpu-ossci-rocm" + test_runs_on: "linux-mi325-1gpu-ossci-rocm-frac" therock_ci_summary: name: TheRock CI Summary diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml index 27eff4fdb0..565d1d3e54 100644 --- a/.github/workflows/therock-test-component.yml +++ b/.github/workflows/therock-test-component.yml @@ -51,7 +51,7 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: "ROCm/TheRock" - ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit + ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit - name: Run setup test environment workflow uses: './.github/actions/setup_test_environment' diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index 81632fce48..cd255a40b6 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit + ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit - name: "Configuring CI options" env: From fbbdd36ea880aaee1eb4691f1c670492fa388647 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 11 Dec 2025 10:47:19 +0400 Subject: [PATCH 37/65] docs: add notes on tile distribution and inline comments (#3297) * docs: add notes on tile distribution and inline comments * Apply 
suggestions from code review Co-authored-by: spolifroni-amd --------- Co-authored-by: spolifroni-amd --- .../01_naive_gemm/TILE_DISTRIBUTION.md | 312 ++++++++++++++++++ ...ice_gemm_block_policy_agmem_bgmem_creg.hpp | 12 +- ...ce_gemm_host_pipeline_agmem_bgmem_creg.hpp | 2 +- .../ck_tile/01_naive_gemm/practice_gemm.cpp | 34 +- .../ck_tile/01_naive_gemm/practice_gemm.hpp | 7 +- 5 files changed, 347 insertions(+), 20 deletions(-) create mode 100644 tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md diff --git a/tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md b/tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md new file mode 100644 index 0000000000..275d1a1c12 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md @@ -0,0 +1,312 @@ +# Tile Distribution: Mapping Threads to Data + +## Overview + +**Tile Distribution** describes how each thread in a thread block maps to elements of a block tile. It defines the hierarchical pattern of data distribution across threads, warps, and thread blocks. + +## The Problem + +Given a block tile of size `MPerBlock × KPerBlock` (e.g., 256×32), we need to determine: +- Which threads load which elements. +- How the threads are organized into warps. +- The number of times each warp repeats its pattern. +- The number of elements each thread can load in a single vector instruction. 
+ +--- + +## Bottom-Up Construction Approach + +### Step 1: Determine K Dimension Layout + +**Start with the innermost dimension (K) for memory coalescing:** + +```cpp +constexpr index_t K1 = 16 / sizeof(ADataType); // Elements per thread (vector load) +constexpr index_t K0 = kKPerBlock / K1; // Threads needed in K dimension +``` + +**Example (with fp16):** +- `K1 = 16 / 2 = 8` → Each thread loads 8 fp16 elements in a single vector instruction +- `kKPerBlock = 32` +- `K0 = 32 / 8 = 4` → We need 4 threads along K to cover the entire K dimension + +**Visual:** +``` +K dimension (32 elements): +Thread 0: [0-7] Thread 1: [8-15] Thread 2: [16-23] Thread 3: [24-31] + K1=8 K1=8 K1=8 K1=8 +├──────────────────────────────────────────────────────────────┤ + K0=4 threads +``` + +--- + +### Step 2: Determine M Dimension Layout + +**Now partition the M dimension hierarchically:** + +#### Level 1: Threads per Warp in M (M2) + +```cpp +constexpr index_t M2 = get_warp_size() / K0; +``` + +- Warp size = 64 threads +- K dimension already uses `K0 = 4` threads per row +- `M2 = 64 / 4 = 16` → Each warp can have 16 threads in M dimension + +**Visual (Single Warp):** +``` + K dimension (4 threads) + ┌─────┬─────┬─────┬─────┐ + 0 │ T0 │ T1 │ T2 │ T3 │ + 1 │ T4 │ T5 │ T6 │ T7 │ + 2 │ T8 │ T9 │ T10 │ T11 │ +M 3 │ T12 │ T13 │ T14 │ T15 │ ← 16 rows + ...│ ... │ ... │ ... │ ... 
│ (M2=16) + 15 │ T60 │ T61 │ T62 │ T63 │ + └─────┴─────┴─────┴─────┘ + One Warp = 64 threads +``` + +#### Level 2: Warps per Block (M1) + +```cpp +constexpr index_t M1 = kBlockSize / get_warp_size(); +``` + +- `kBlockSize = 256` threads per block +- `M1 = 256 / 64 = 4` → We have 4 warps per block + +**Visual (4 Warps):** +``` + Warp 0 (rows 0-15) + Warp 1 (rows 16-31) + Warp 2 (rows 32-47) + Warp 3 (rows 48-63) + ↑ + M1 = 4 warps cover 64 rows total +``` + +#### Level 3: Repetitions (M0) + +```cpp +constexpr index_t M0 = kMPerBlock / (M2 * M1); +``` + +- `kMPerBlock = 256` rows to cover +- `M2 * M1 = 16 * 4 = 64` rows covered by all warps +- `M0 = 256 / 64 = 4` → Each warp must repeat its pattern 4 times + +**Visual (Complete Block):** +``` +┌──────────────┐ +│ Iteration 0 │ ← Warp 0: rows 0-15, Warp 1: rows 16-31, ... +│ (rows 0-63) │ +├──────────────┤ +│ Iteration 1 │ ← Warp 0: rows 64-79, Warp 1: rows 80-95, ... +│ (rows 64-127)│ +├──────────────┤ +│ Iteration 2 │ ← Warp 0: rows 128-143, Warp 1: rows 144-159, ... +│(rows 128-191)│ +├──────────────┤ +│ Iteration 3 │ ← Warp 0: rows 192-207, Warp 1: rows 208-223, ... +│(rows 192-255)│ +└──────────────┘ + M0 = 4 iterations +``` + +--- + +## The Tile Distribution Encoding + +Now we can construct the distribution: + +```cpp +tile_distribution_encoding< + sequence<1>, // [1] Replication + tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // [2] Hierarchy + tuple<sequence<1>, sequence<1, 2>>, // [3] Parallelism + tuple<sequence<1>, sequence<2, 0>>, // [3] Parallelism + sequence<1, 2>, // [4] Yield + sequence<0, 1> // [4] Yield +> +``` + +### [1] Replication: `sequence<1>` + +Defines how many times warp patterns are replicated: +- `1` = Each warp has a unique pattern (no replication) +- `2` = Warp 0 and Warp 1 do the same thing, Warp 2 and Warp 3 do the same thing +- `4` = All warps do the same thing + +In our case: `1` means no replication (each warp is independent). 
+ +--- + +### [2] Hierarchy: The Multi-Level Structure + +```cpp +tuple<sequence<M0, M1, M2>, sequence<K0, K1>> + └───────┬──────────┘ └──────┬────────┘ + M dimension K dimension +``` + +**Concrete values:** +- M hierarchy: `sequence<4, 4, 16>` = (4 repetitions, 4 warps, 16 threads/warp) +- K hierarchy: `sequence<4, 8>` = (4 threads, 8 elements/thread) + +--- + +### [3] Parallelism: Addressing the Hierarchy + +**The key insight:** Read the tuples **vertically** to understand indexing! + +```cpp +tuple<sequence<1>, sequence<1, 2>> +tuple<sequence<1>, sequence<2, 0>> +``` + +#### Reading Pattern + +**Column 1 (Dimension 0 = M):** +``` +sequence<1> → Address hierarchy index 1,1 → M1 (warps/block in M dimension) +sequence<1> +``` + +**Column 2 (Dimension 1 = K):** +``` +sequence<1, 2> +sequence<2, 0> +``` +[1,2] M2=threads/warp in M dimension +[2,0] K0=threads/warp in K dimension + +--- + +### [4] Yield Sequences: Output Ordering + +```cpp +sequence<1, 2> +sequence<0, 1> + +[1,0] means M0=repetitions/warp in M dimension +[2,1] means K1=elements/thread in K dimension +``` +--- + +## Complete Example: Thread 25 in Warp 0 + +Let's trace where **Thread 25** in **Warp 0** reads data: + +### Thread Coordinates +- Thread ID in warp: 25 +- Warp ID in block: 0 + +### Decompose Thread 25 +``` +Thread 25 in a 2D layout (M2=16, K0=4): +Row index: 25 / 4 = 6 +Col index: 25 % 4 = 1 +``` + +### M Position (Row) +``` +M0 iteration: 0 (first iteration) +M1 warp: 0 (warp 0) +M2 thread: 6 (row 6 in warp) +→ M position = 0*64 + 0*16 + 6 = 6 +``` + +### K Position (Column) +``` +K0 thread: 1 (column group 1) +K1 elements: 8 (will load 8 consecutive elements) +→ K position = 1*8 + [0-7] = elements 8-15 +``` + +**Result:** Thread 25 in Warp 0 loads **row 6, columns 8-15** (8 elements). + +--- + +## Why This Matters + +### 1. **Memory Coalescing** +- Consecutive threads access consecutive memory → efficient global memory access +- K dimension uses K1=8 for vectorized loads + +### 2. 
**Warp Efficiency** +- All 64 threads in a warp are utilized +- Natural 2D layout: 16 threads (M) × 4 threads (K) = 64 threads + +### 3. **Scalability** +- M0 repetitions allow handling larger tiles +- Same pattern scales to different sizes + +### 4. **Register Allocation** +- Each thread knows exactly how many elements it will hold +- Compiler can allocate registers optimally + +--- + +## Summary Table + +| Parameter | Value | Meaning | +|-----------|-------|---------| +| **K1** | 8 | Elements per thread (vector width) | +| **K0** | 4 | Threads along K per row | +| **M2** | 16 | Threads along M per warp | +| **M1** | 4 | Warps per block | +| **M0** | 4 | Repetitions of warp pattern | +| **Total Threads** | 256 | M1×M2×K0 = 4×16×4 (M1 warps × 64 threads per warp) | +| **Total Elements** | 8192 | 256×32 (MPerBlock × KPerBlock) | +| **Elements/Thread** | 32 | M0×K1 = 4×8 | + +--- + +## Visualization: Complete Thread Block + +``` +Block Tile: 256×32 + + K dimension (32 elements) + ├─────────────────────┤ + ┌──────────────────────┐ ┐ + 0 │ Warp 0 │ │ + 16 │ Warp 1 │ │ Iteration 0 + 32 │ Warp 2 │ │ (M0=0) + 48 │ Warp 3 │ ┘ + ├──────────────────────┤ ┐ + 64 │ Warp 0 │ │ + 80 │ Warp 1 │ │ Iteration 1 + 96 │ Warp 2 │ │ (M0=1) + 112 │ Warp 3 │ ┘ + ├──────────────────────┤ ┐ + 128 │ Warp 0 │ │ + 144 │ Warp 1 │ │ Iteration 2 + 160 │ Warp 2 │ │ (M0=2) + 176 │ Warp 3 │ ┘ + ├──────────────────────┤ ┐ + 192 │ Warp 0 │ │ + 208 │ Warp 1 │ │ Iteration 3 + 224 │ Warp 2 │ │ (M0=3) + 240 │ Warp 3 │ ┘ + └──────────────────────┘ + +Each warp processes 16 rows × 32 cols = 512 elements per iteration +Each iteration processes 64 rows × 32 cols = 2048 elements +Total: 4 iterations × 2048 = 8192 elements ✓ +``` + +--- + +## Key Takeaways + +1. **Bottom-up construction**: Start from vector width (K1), build up through thread/warp/block hierarchy +2. **Vertical reading**: The two parallelism tuples (and the two yield sequences) are read column-wise to address hierarchy levels +3. 
**Replication controls redundancy**: How many warps share the same pattern +4. **Hierarchy encodes structure**: The multi-level sequence defines the complete mapping + +This design enables CK to achieve maximum GPU performance through optimal thread-to-data mapping! + diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp index 2921bce8bf..a3ed982488 100644 --- a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp @@ -98,12 +98,12 @@ struct PracticeGemmBlockPolicy constexpr index_t M0 = kMPerBlock / (M2 * M1); return make_static_tile_distribution( - tile_distribution_encoding<sequence<1>, - tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, - tuple<sequence<1>, sequence<1, 2>>, - tuple<sequence<1>, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + tile_distribution_encoding<sequence<1>, // replication + tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // hierarchy + tuple<sequence<1>, sequence<1, 2>>, // parallelism + tuple<sequence<1>, sequence<2, 0>>, // parallelism + sequence<1, 2>, // yield + sequence<0, 1>>{}); // yield } template diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp index dd72f08d99..15c1743a86 100644 --- a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp @@ -24,7 +24,7 @@ struct PracticeGemmHostPipeline template CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, const BDRAMTensorView& b_dram, - CDRAMTensorView& c_dram_ref) const + CDRAMTensorView& c_dram) const { // Size of the entire problem diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp 
index 4f0bc13dd5..7635c9376b 100644 --- a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp @@ -6,7 +6,7 @@ #include "practice_gemm.hpp" #include "reference_gemm.hpp" -int main() +int main(int argc, char* argv[]) { // TODO: GemmTypeConfig using ADataType = ck_tile::half_t; @@ -14,11 +14,22 @@ int main() using CDataType = float; using AccDataType = float; - // ArgParser - ck_tile::index_t M = 512; - ck_tile::index_t N = 256; - ck_tile::index_t K = 64; - ck_tile::index_t verification = 1; + // Setup simple argument parser for M, N, K + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "512", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "64", "k dimension") + .insert("v", "1", "verification: 0=off, 1=on"); + + auto result = arg_parser.parse(argc, argv); + if(!result) + return -1; + + // Get problem dimensions from command line + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + ck_tile::index_t verification = arg_parser.get_int("v"); ck_tile::index_t stride_a = K; ck_tile::index_t stride_b = K; @@ -61,9 +72,6 @@ int main() ck_tile::DeviceMem c_device(c_host); // TODO: BlockTileConfig - // constexpr ck_tile::index_t warpSize = 64; - constexpr ck_tile::index_t kBlockSize = 256; - using BlockTile = ck_tile::sequence<256, 128, 32>; using WaveTile = ck_tile::sequence<16, 16, 16>; @@ -77,11 +85,13 @@ int main() ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) * ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N); - std::cout << "kGridSize: " << kGridSize << std::endl; + std::cout << "Total number of thread blocks: " << kGridSize << std::endl; constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU - std::cout << "kBlockSize: " << kBlockSize << std::endl; - std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl; + // Block size is now 
derived from the shape configuration + constexpr ck_tile::index_t kBlockSize = PracticeGemmShape::kBlockSize; + std::cout << "Number of threads per block: " << kBlockSize << std::endl; + std::cout << "Number of blocks per compute unit: " << kBlockPerCU << std::endl; using gemm_kernel = ck_tile::PracticeGemmKernel; diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp index 850e6ae3b3..91d7fae90c 100644 --- a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp @@ -24,6 +24,10 @@ struct PracticeGemmShape static constexpr index_t WaveTile_N = WaveTile::at(number<1>{}); static constexpr index_t WaveTile_K = WaveTile::at(number<2>{}); + // Thread block configuration + static constexpr index_t kWarpSize = 64; // AMD GPU warp size (also called wavefront) + static constexpr index_t kBlockSize = 256; // Total threads per block (4 warps × 64 threads) + CK_TILE_HOST static std::string GetName() { // clang-format off @@ -40,7 +44,8 @@ struct PracticeGemmKernel using Problem = remove_cvref_t; using Policy = remove_cvref_t; - static constexpr index_t kBlockSize = 256; + // Derive block size from the shape configuration + static constexpr index_t kBlockSize = Problem::Shape::kBlockSize; CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a, const typename Problem::BDataType* p_b, From 6d25525adc2344d5b62b12b9ffddee50f89cd0ff Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 11 Dec 2025 10:50:43 +0400 Subject: [PATCH 38/65] feat(precommit-hooks): add check for correct copyright header (#3302) * chore(copyright): update copyright header for left files * feat(copyright): add copyright check to precommit hooks * chore(copyright): update copyright header for include/ck_tile directory * chore(copyright): update copyright header for example directory * chore(copyright): update copyright header for .github directory * refactor: copyright_check script with 
better if else handling * chore(copyright): update copyright header for remaining files * feat: add script to automate copyright addition --- .github/scripts/therock_configure_ci.py | 3 + .pre-commit-config.yaml | 12 +- include/ck_tile/core.hpp | 3 +- include/ck_tile/host.hpp | 3 +- include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 3 +- include/ck_tile/ops/batched_contraction.hpp | 3 +- include/ck_tile/ops/batched_transpose.hpp | 3 +- include/ck_tile/ops/common.hpp | 3 +- include/ck_tile/ops/elementwise.hpp | 3 +- include/ck_tile/ops/epilogue.hpp | 3 +- include/ck_tile/ops/flatmm.hpp | 3 +- include/ck_tile/ops/fmha.hpp | 3 +- include/ck_tile/ops/fused_moe.hpp | 3 +- include/ck_tile/ops/gemm.hpp | 3 +- include/ck_tile/ops/gemm_quant.hpp | 3 +- include/ck_tile/ops/grouped_convolution.hpp | 3 +- include/ck_tile/ops/image_to_column.hpp | 3 +- include/ck_tile/ops/layernorm2d.hpp | 3 +- include/ck_tile/ops/norm_reduce.hpp | 3 +- include/ck_tile/ops/permute.hpp | 3 +- include/ck_tile/ops/pooling.hpp | 3 +- include/ck_tile/ops/reduce.hpp | 3 +- include/ck_tile/ops/rmsnorm2d.hpp | 3 +- include/ck_tile/ops/smoothquant.hpp | 3 +- include/ck_tile/ops/softmax.hpp | 3 +- include/ck_tile/ops/topk.hpp | 3 +- include/ck_tile/ops/topk_softmax.hpp | 3 +- include/ck_tile/remod.py | 5 +- script/check_copyright_year.sh | 70 ++++- script/update_amd_copyright_headers.py | 295 ++++++++++++++++++ .../ck_tile/core/arch/mma/test_amdgcn_mma.cpp | 4 +- test/ck_tile/core/arch/test_arch.cpp | 4 +- tile_engine/include/utility/validation.hpp | 2 +- tile_engine/ops/gemm_streamk/CMakeLists.txt | 3 + .../gemm_streamk/gemm_streamk_benchmark.hpp | 2 +- .../gemm_streamk_benchmark_single.cpp | 2 +- .../ops/gemm_streamk/gemm_streamk_common.hpp | 2 +- .../gemm_streamk_instance_builder.py | 3 + .../gemm_streamk/gemm_streamk_profiler.hpp | 2 +- .../gemm_streamk_validation_utils.py | 2 +- 40 files changed, 408 insertions(+), 78 deletions(-) create mode 100644 script/update_amd_copyright_headers.py diff --git 
a/.github/scripts/therock_configure_ci.py b/.github/scripts/therock_configure_ci.py index 860b6bf875..c892941fc6 100644 --- a/.github/scripts/therock_configure_ci.py +++ b/.github/scripts/therock_configure_ci.py @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import fnmatch import json import os diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04ebc6b45a..71e7ccdb81 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,12 +20,12 @@ repos: )$ - repo: local hooks: - # - id: copyright-year-checker - # name: copyright-year-checker - # entry: script/check_copyright_year.sh - # verbose: false - # language: script - # types: [c++] + - id: copyright-header-checker + name: Check copyright headers + entry: script/check_copyright_year.sh + verbose: false + language: script + types_or: [c++, python, shell, cmake] - id: remove-exec-bit name: Remove executable bit from non-executable files entry: script/remove_exec_bit.sh diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 5c05e9b6ee..d28d29a0ef 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/core/algorithm/cluster_descriptor.hpp" diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index c769e3e247..b543fd84e9 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/host/arg_parser.hpp" diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 6c0972e10a..00234b20cf 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp" diff --git a/include/ck_tile/ops/batched_contraction.hpp b/include/ck_tile/ops/batched_contraction.hpp index 2232ec1261..45fa52e505 100644 --- a/include/ck_tile/ops/batched_contraction.hpp +++ b/include/ck_tile/ops/batched_contraction.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index 5822d7b91b..b23e45c233 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index eff2d625b3..94243e674f 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/common/generic_2d_block_shape.hpp" diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 7f2303932e..5752703ab6 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index ec5a8ef445..555402b53a 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 7ef2fd5433..2d3a819e80 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 5b87a821c9..20714397c9 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index 71721f3408..e6802e82dc 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index ec2d2488c8..d518a15b7e 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 3e16d937cb..7dc5b40286 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 23a72d79e9..6743e46613 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 2307b05190..1d33ebf39d 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index 9ce22137bf..ebb20aebf4 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index aa074b7f9f..469a98c256 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 46512c57fe..88a3d8a137 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" diff --git a/include/ck_tile/ops/pooling.hpp b/include/ck_tile/ops/pooling.hpp index 084b498203..3e44122afa 100644 --- a/include/ck_tile/ops/pooling.hpp +++ b/include/ck_tile/ops/pooling.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/pooling/kernel/pool_kernel.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index d628e9c945..57f3f3c80a 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/reduce/block/block_reduce.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 00afcf4aed..ad23a708b7 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index 1aa14c69e1..13372f3289 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index d559dc15e2..9cf3e08319 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 040c6b8ddc..090ad0919f 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index d9657a9764..7afce1708b 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp" diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index aeec7bd471..51f3941233 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -1,7 +1,6 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -from datetime import datetime import pathlib from pathlib import Path import subprocess @@ -13,8 +12,8 @@ OPS = "ops" OPS_COMMON = "common" # common header will be duplicated into ops/* other module IGNORED_DIRS = ["utility", "ref"] -HEADER_COMMON = f"""// SPDX-License-Identifier: MIT -// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n +HEADER_COMMON = """// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT """ diff --git a/script/check_copyright_year.sh b/script/check_copyright_year.sh index 1b63c6b711..48c050c76b 100755 --- a/script/check_copyright_year.sh +++ b/script/check_copyright_year.sh @@ -2,18 +2,70 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT +# This script checks if files have the correct copyright header template. +# It supports .hpp, .cpp, .inc, .py, .sh, and .cmake files. +# +# Usage: ./check_copyright_year.sh ... -current_year=$(date +%Y) exit_code=0 -for file in $@; do - if grep -q "Copyright (c)" $file - then - if ! grep -q "Copyright (c).*$current_year" $file - then - echo "ERROR: File $file has a copyright notice without the current year ($current_year)." - exit_code=1 - fi +# Expected copyright header lines (without comment characters) +COPYRIGHT_LINE="Copyright (c) Advanced Micro Devices, Inc., or its affiliates." 
+SPDX_LINE="SPDX-License-Identifier: MIT" + +check_file() { + local file=$1 + local basename="${file##*/}" + local ext="${file##*.}" + local comment_char + + # Determine comment character based on filename or extension + if [[ "$basename" == "CMakeLists.txt" ]]; then + comment_char="#" + else + case "$ext" in + cpp|hpp|inc) + comment_char="//" + ;; + py|sh|cmake) + comment_char="#" + ;; + *) + # Skip files with unsupported extensions + return 0 + ;; + esac + fi + + # Build expected header patterns + expected_copyright="$comment_char $COPYRIGHT_LINE" + expected_spdx="$comment_char $SPDX_LINE" + + # Check if file contains both required lines + if ! grep -qF "$expected_copyright" "$file"; then + echo "ERROR: File $file is missing the correct copyright header line." + echo " Expected: $expected_copyright" + return 1 + fi + + if ! grep -qF "$expected_spdx" "$file"; then + echo "ERROR: File $file is missing the correct SPDX license identifier line." + echo " Expected: $expected_spdx" + return 1 + fi + + return 0 +} + +# Process each file provided as argument +for file in "$@"; do + # Skip if file doesn't exist or is a directory + if [[ ! -f "$file" ]]; then + continue + fi + + if ! check_file "$file"; then + exit_code=1 fi done diff --git a/script/update_amd_copyright_headers.py b/script/update_amd_copyright_headers.py new file mode 100644 index 0000000000..489b774e97 --- /dev/null +++ b/script/update_amd_copyright_headers.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Purpose: + Normalize and enforce AMD two-line copyright + SPDX headers across files. + +Target files: + - C/C++-style: .cpp, .hpp, .inc -> uses "//" comment style + - Hash-style: .py, .cmake, .sh, and CMakeLists.txt -> uses "#" style + +Header formats inserted (top of file, followed by exactly one blank line): + C/C++ : + // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+ // SPDX-License-Identifier: MIT + + Hash : + + +Shebang special case (hash-style only): + - If line 1 starts with "#!", keep shebang, then a blank line, then the + two hash-style header lines, then a blank line. + +Removal rules: + - Remove any comment lines (anywhere in file) containing the keywords + "copyright" or "spdx" (case-insensitive). Blank lines are preserved. + - Remove long-form MIT license block comment when: + a) The file starts with the block (absolute top), OR + b) The block appears immediately after the AMD header position + (i.e., when remainder at insertion point begins with "/*" and + the first content line is "* The MIT License (MIT)"). + +Blank-line normalization: + - Enforce exactly ONE blank line immediately after the AMD header. + (Drop only the leading blank lines at the insertion point before + re-inserting the header.) + - Do not change blank lines between other non-copyright comments. + +Preservation: + - Preserve original newline style: CRLF (\r\n) vs LF (\n). + - Preserve UTF-8 BOM if present. + - Do not modify non-comment code lines. + +Idempotency: + - Running this script multiple times does not further modify files. 
+""" + +from __future__ import annotations +import re +import sys +from pathlib import Path +from typing import List, Tuple + +AMD_CPP_HEADER_TEXT = [ + "// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.", + "// SPDX-License-Identifier: MIT", +] +AMD_HASH_HEADER_TEXT = [ + "# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.", + "# SPDX-License-Identifier: MIT", +] + +CPP_EXTS = {".cpp", ".hpp", ".inc"} +HASH_EXTS = {".py", ".cmake", ".sh"} + +# --- Encoding helpers ------------------------------------------------------- + + +def has_bom(raw: bytes) -> bool: + return raw.startswith(b"\xef\xbb\xbf") + + +def decode_text(raw: bytes) -> str: + return raw.decode("utf-8-sig", errors="replace") + + +def encode_text(text: str, bom: bool) -> bytes: + data = text.encode("utf-8") + return (b"\xef\xbb\xbf" + data) if bom else data + + +# --- Newline detection ------------------------------------------------------ + + +def detect_newline_sequence(raw: bytes) -> str: + if b"\r\n" in raw: + return "\r\n" + elif b"\n" in raw: + return "\n" + else: + return "\n" + + +# --- Utilities -------------------------------------------------------------- + + +def is_comment_line(line: str, style: str) -> bool: + stripped = line.lstrip() + if style == "cpp": + return ( + stripped.startswith("//") + or stripped.startswith("/*") + or stripped.startswith("*") + or stripped.startswith("*/") + ) + elif style == "hash": + return stripped.startswith("#") + return False + + +def has_keywords(line: str) -> bool: + lower_line = line.lower() + return ("copyright" in lower_line) or ("spdx" in lower_line) + + +# --- MIT License banner detection ------------------------------ +MIT_C_FIRST_LINE_RE = re.compile(r"^\s*\*\s*The MIT License \(MIT\)") +MIT_HASH_FIRST_LINE_RE = re.compile(r"^\s*#\s*The MIT License \(MIT\)") + + +def remove_top_mit_block(lines: List[str]) -> Tuple[List[str], bool]: + """ + Unified MIT banner removal at the top of 'lines'. 
+ Supports: + - C-style block starting with '/*' and ending with '*/'; removes only if + a line within the block matches MIT_C_FIRST_LINE_RE. + - Hash-style banner: contiguous top run of lines starting with '#'; + removes only if any line in that run matches MIT_HASH_FIRST_LINE_RE. + Returns (new_lines, removed_flag). Preserves EOLs. + """ + if not lines: + return lines, False + + first = lines[0].lstrip() + + # C-style block + if first.startswith("/*"): + end_idx, saw_mit = None, False + for i, line in enumerate(lines[1:], 1): + if not saw_mit and MIT_C_FIRST_LINE_RE.match(line): + saw_mit = True + s = line.lstrip() + if s.startswith("*/") or s.rstrip().endswith("*/"): + end_idx = i + 1 + break + if end_idx is not None and saw_mit: + return lines[end_idx:], True + return lines, False + + # Hash-style contiguous banner + if first.startswith("#"): + end_idx, saw_mit = 0, False + for i, line in enumerate(lines): + if line.lstrip().startswith("#"): + if not saw_mit and MIT_HASH_FIRST_LINE_RE.match(line): + saw_mit = True + end_idx = i + 1 + else: + break + if saw_mit: + return lines[end_idx:], True + return lines, False + + return lines, False + + +# --- Removal + normalization helpers --------------------------------------- + + +def remove_keyword_comment_lines_globally(lines: List[str], style: str) -> List[str]: + """Remove comment lines containing keywords anywhere in the file. 
+ **Do not** remove blank lines; preserve all other lines as-is.""" + out: List[str] = [] + for line in lines: + if is_comment_line(line, style) and has_keywords(line): + continue + out.append(line) + return out + + +def drop_leading_blank_lines(lines: List[str]) -> List[str]: + """Drop only the leading blank lines at the start of the given list.""" + i = 0 + while i < len(lines) and lines[i].strip() == "": + i += 1 + return lines[i:] + + +# --- Header builder --------------------------------------------------------- + + +def build_header_lines(style: str, nl: str) -> List[str]: + base = AMD_CPP_HEADER_TEXT if style == "cpp" else AMD_HASH_HEADER_TEXT + return [base[0] + nl, base[1] + nl, nl] # header + exactly one blank + + +# --- Main transforms -------------------------------------------------------- + + +def process_cpp(text: str, nl: str) -> str: + lines = text.splitlines(True) + + # Remove MIT block if it is at the *absolute* top + lines, _ = remove_top_mit_block(lines) + + # Remove keyworded comment lines globally (blank lines preserved) + lines = remove_keyword_comment_lines_globally(lines, style="cpp") + + # Normalize insertion point and remove MIT block if it appears *after header* + lines = drop_leading_blank_lines(lines) + lines, _ = remove_top_mit_block(lines) + + # Prepend AMD header (guarantee exactly one blank after) + return "".join(build_header_lines("cpp", nl) + lines) + + +def process_hash(text: str, nl: str) -> str: + lines = text.splitlines(True) + if not lines: + return "".join(build_header_lines("hash", nl)) + + shebang = lines[0].startswith("#!") + + if shebang: + remainder = remove_keyword_comment_lines_globally(lines[1:], style="hash") + remainder = drop_leading_blank_lines(remainder) + remainder, _ = remove_top_mit_block(remainder) # remove MIT block after header + new_top = [lines[0], nl] + build_header_lines("hash", nl) + return "".join(new_top + remainder) + else: + remainder = remove_keyword_comment_lines_globally(lines, style="hash") 
+ remainder = drop_leading_blank_lines(remainder) + remainder, _ = remove_top_mit_block(remainder) # remove MIT block after header + return "".join(build_header_lines("hash", nl) + remainder) + + +# --- File processing & CLI -------------------------------------------------- + + +def process_file(path: Path) -> bool: + name = path.name + suffix = path.suffix.lower() + if suffix in CPP_EXTS: + style = "cpp" + elif suffix in HASH_EXTS or name == "CMakeLists.txt": + style = "hash" + else: + return False + + raw = path.read_bytes() + bom = has_bom(raw) + nl = detect_newline_sequence(raw) + text = decode_text(raw) + + updated = process_cpp(text, nl) if style == "cpp" else process_hash(text, nl) + if updated != text: + path.write_bytes(encode_text(updated, bom)) + return True + return False + + +def main(argv: List[str]) -> int: + if len(argv) < 2: + print(__doc__) + return 2 + changed = 0 + skipped = 0 + errors: List[str] = [] + for arg in argv[1:]: + p = Path(arg) + try: + if not p.exists(): + errors.append(f"Not found: {p}") + continue + if p.is_dir(): + errors.append(f"Is a directory (pass specific files): {p}") + continue + if process_file(p): + changed += 1 + print(f"Updated: {p}") + else: + skipped += 1 + print(f"Skipped (no change needed or unsupported type): {p}") + except Exception as e: + errors.append(f"Error processing {p}: {e}") + print(f"\nSummary: {changed} updated, {skipped} skipped, {len(errors)} errors") + for msg in errors: + print(f" - {msg}") + return 0 if not errors else 1 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp index 4121e199e2..c7093e3477 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #include #include diff --git a/test/ck_tile/core/arch/test_arch.cpp b/test/ck_tile/core/arch/test_arch.cpp index 2d553c1595..f015d3ce0a 100644 --- a/test/ck_tile/core/arch/test_arch.cpp +++ b/test/ck_tile/core/arch/test_arch.cpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #include #include "ck_tile/core/arch/arch.hpp" diff --git a/tile_engine/include/utility/validation.hpp b/tile_engine/include/utility/validation.hpp index dc57e6cc6a..f10f37fbaa 100644 --- a/tile_engine/include/utility/validation.hpp +++ b/tile_engine/include/utility/validation.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c), Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_streamk/CMakeLists.txt b/tile_engine/ops/gemm_streamk/CMakeLists.txt index acfd78edc5..c692a6d247 100644 --- a/tile_engine/ops/gemm_streamk/CMakeLists.txt +++ b/tile_engine/ops/gemm_streamk/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + set(GEMM_STREAMK_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") set(GEMM_STREAMK_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)") set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp index fa8a019be5..45beb0acce 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp index 5e88dc486a..9dbba04082 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp index 15a3c91964..2708ac2e56 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 6aebc54564..2225619fad 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import json diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp index 256e0b9ca4..0541116522 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py index 2288d7752f..bef3cdfe85 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py @@ -1,6 +1,6 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. """ Validation utilities for GEMM kernel generation. From d66e5f667c9d36b9c4ad8fa0cae7dd48ec9d5ebb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Thu, 11 Dec 2025 09:50:00 +0200 Subject: [PATCH 39/65] [CK_BUILDER] Improve CK Builder and CK Builder tests (#3382) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove stale documentation. * Add placeholder for conv algorithm design description. 
Add link to conv factory description. * Improve testing transfer parameters. * Python script to check the block tilings. * Improve tests and conv types serialization. * Change representation of boolean values from 1/0 to true/false in instance strings. * Change representation of boolean values from 1/0 to true/false in conv algorithm types. * Test code improvements. * Improve covn descriptions tests. * Improve conv signature definition in conv fwd builder tests. * clang-format. * Remove obsolete script. * Revert StaticAssertTypeEq changes in conv layout tests. * Remove obsolete using declaration. --------- Co-authored-by: Ville Pietilä <> --- .../builder/include/ck_tile/builder/README.md | 30 +- .../factory/helpers/ck/conv_tensor_type.hpp | 12 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 10 +- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 10 +- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 10 +- .../builder/include/ck_tile/builder/types.hpp | 326 ++++--- .../conv/ck/test_ckb_conv_fwd_1d_bf16.cpp | 23 +- .../conv/ck/test_ckb_conv_fwd_1d_fp16.cpp | 21 +- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 21 +- .../conv/ck/test_ckb_conv_fwd_2d_bf16.cpp | 40 +- ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 30 +- .../conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp | 41 +- .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 21 +- .../conv/ck/test_ckb_conv_fwd_2d_fp32.cpp | 21 +- .../test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp | 21 +- ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 41 +- .../conv/ck/test_ckb_conv_fwd_3d_bf16.cpp | 24 +- .../conv/ck/test_ckb_conv_fwd_3d_fp16.cpp | 24 +- .../conv/ck/test_ckb_conv_fwd_3d_fp32.cpp | 24 +- .../builder/test/conv/ck/test_conv_traits.cpp | 52 +- .../test/impl/conv_signature_types.hpp | 4 +- ..._grouped_convolution_forward_convscale.cpp | 864 +++++++++--------- ...grouped_convolution_forward_dynamic_op.cpp | 48 +- ...rouped_convolution_forward_scaleadd_ab.cpp | 96 +- ...olution_forward_scaleadd_scaleadd_relu.cpp | 96 +- 
.../builder/test/test_conv_description.cpp | 48 +- .../builder/test/test_fwd_instance_traits.cpp | 12 +- .../test_instance_string_fwd_grp_conv.cpp | 4 +- ...tance_string_fwd_grp_conv_large_tensor.cpp | 4 +- .../test_instance_string_fwd_grp_conv_v3.cpp | 4 +- .../builder/test/test_testing_utils.cpp | 4 +- .../builder/test/unit_conv_tensor_layout.cpp | 278 +++--- .../test/utils/conv_algorithm_type_utils.hpp | 346 +++++++ 33 files changed, 1568 insertions(+), 1042 deletions(-) create mode 100644 experimental/builder/test/utils/conv_algorithm_type_utils.hpp diff --git a/experimental/builder/include/ck_tile/builder/README.md b/experimental/builder/include/ck_tile/builder/README.md index a0522a50d6..8075e33220 100644 --- a/experimental/builder/include/ck_tile/builder/README.md +++ b/experimental/builder/include/ck_tile/builder/README.md @@ -4,14 +4,16 @@ This directory contains the builder framework for Composable Kernel, which provi ## Table of Contents -- [Convolution Signature Design](#convolution-signature-design) +- [Convolution Signature](#convolution-signature) - [Overview](#overview) - [Architecture](#architecture) - [Core Components](#core-components) - [Concepts and Validation](#concepts-and-validation) +- [Convolution Algorithm](#convolution-algorithm) +- [Convolution Factory](#convolution-factory) --- -## Convolution Signature Design +## Convolution Signature ### Overview @@ -220,25 +222,9 @@ Several fields in the signature are optional: This design follows the principle of "make the common case simple, the complex case possible." -#### Union-Based Layout Representation +## Convolution Algorithm -The `ConvLayout` type uses unions to support dimension-agnostic code: +## Convolution Factory -```cpp -struct ConvLayout { - union { - ConvInputLayout _input_layout; - ConvWeightLayout _weight_layout; - ConvOutputLayout _output_layout; - ConvAuxiliaryTensorLayout _aux_tensor_layout; - }; - // ... 
constructors for each type -}; -``` - -This allows: -- Single type to represent all layout variants -- Type-safe construction through overloaded constructors -- Compile-time enforcement of valid combinations through concepts - ---- +Convolution factory builds the instance based on the convolution signature and convolution algorithm. +The signature and the algorithm descriptions are dispatched to the relevant algorithm specific factory for instance creation. The convolution factory design is described in a separate [Readme](factory/README.md). diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index 81de2140f2..c819e11d00 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -65,17 +65,19 @@ consteval auto GetTensorDataAndComputeTypes() constexpr auto data_type = Config.data_type; constexpr auto compute_type = Config.compute_type; - if constexpr(data_type == DataType::UNDEFINDED && compute_type == DataType::UNDEFINDED) + using enum DataType; + + if constexpr(data_type == UNDEFINED_DATA_TYPE && compute_type == UNDEFINED_DATA_TYPE) { return std::make_pair(ConvertDataTypeToCK(), ConvertDataTypeToCK()); } - else if constexpr(data_type == DataType::UNDEFINDED) + else if constexpr(data_type == UNDEFINED_DATA_TYPE) { return std::make_pair(ConvertDataTypeToCK(), ConvertDataTypeToCK()); } - else if constexpr(compute_type == DataType::UNDEFINDED) + else if constexpr(compute_type == UNDEFINED_DATA_TYPE) { return std::make_pair(ConvertDataTypeToCK(), ConvertDataTypeToCK()); @@ -91,7 +93,7 @@ template consteval auto GetTensorAccumulationType() { constexpr auto data_type = SignatureAccDataType; - if constexpr(data_type == DataType::UNDEFINDED) + if constexpr(data_type == DataType::UNDEFINED_DATA_TYPE) { return 
ConvertDataTypeToCK(); } @@ -105,7 +107,7 @@ template consteval auto GetAuxiliaryTensorDataTypeValue() { constexpr auto data_type = Config.data_type; - if constexpr(data_type == DataType::UNDEFINDED) + if constexpr(data_type == DataType::UNDEFINED_DATA_TYPE) { return ConvertDataTypeToCK(); } diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 126be93f01..f5f3df3159 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -316,7 +316,7 @@ struct InstanceTraits; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,256,256,32", + expected_transfer_parameters, "NGCW,GKXC,EmptyTuple,NGKW", "PassThrough,PassThrough,Scale", "Filter1x1Stride1Pad0", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index e8cd8fb136..6802e0caf8 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,13 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature 
FwdConvSignature{.spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NWGC}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::NWGK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NWGC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} @@ -29,11 +34,13 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", + expected_transfer_parameters, "NWGC,GKXC,EmptyTuple,NWGK", "PassThrough,PassThrough,PassThrough", "MNKPadding", - "64,64,32,32", "Default"}); } diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index 014e221101..14463bbc17 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -14,13 +15,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .data_type = DataType::I8, - .accumulation_data_type = DataType::INT32, - .input = 
{.config = {.layout = TensorLayout::GNWC}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::GNWK}}}; + .direction = FORWARD, + .data_type = I8, + .accumulation_data_type = INT32, + .input = {.config = {.layout = GNWC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = GNWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} @@ -31,8 +36,10 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", - "128,64,64,64", + expected_transfer_parameters, "GNWC,GKXC,EmptyTuple,GNWK", "PassThrough,PassThrough,PassThrough", "Default"}); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index b98e28c45a..4a5618a6b1 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,13 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::BF16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + .direction = 
FORWARD, + .data_type = BF16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -29,8 +34,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,256,256,32", + expected_transfer_parameters, "Default", "NHWGC,GKYXC,EmptyTuple,NHWGK", "PassThrough,PassThrough,PassThrough", @@ -43,13 +50,17 @@ TEST(FwdConvInstances, TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::BF16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + .direction = FORWARD, + .data_type = BF16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -60,7 +71,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v5_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + expected_transfer_parameters, "Filter3x3", "NHWGC,GKYXC,EmptyTuple,NHWGK", "PassThrough,PassThrough,PassThrough", diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index bc4a5e1047..e3dc261fe3 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,19 +13,22 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_BF16_scale_add_relu) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + using enum ck_tile::builder::ElementwiseOperation; + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::BF16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC, .data_type = DataType::BF16}}, - .output = ConvolutionTensor{ - .config = {.layout = TensorLayout::NHWGK}, - .operation = TensorOperation<>{.elementwise_operation = - ElementwiseOperation::SCALEADD_SCALEADD_RELU} - .with_auxiliary_operand_configs()}}; + .direction = FORWARD, + .data_type = BF16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC, .data_type = BF16}}, + .output = ConvolutionTensor{ + .config = {.layout = NHWGK}, + .operation = TensorOperation<>{.elementwise_operation = SCALEADD_SCALEADD_RELU} + .with_auxiliary_operand_configs()}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} @@ -35,10 +39,12 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); using Builder 
= ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", + expected_transfer_parameters, "NHWGC,GKYXC,Tuple(NHWGK,G_K),NHWGK", "PassThrough,PassThrough,ScaleAddScaleAddRelu", - "64,64,32,32", "MNKPadding", "Default"}); } diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp index 7af1448403..9bea834ef9 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -10,13 +11,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} @@ -27,8 +32,10 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins .with_dl_transfer(DlFwdTransfer); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); 
run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", - "256,128,128,16", + expected_transfer_parameters, "Default", "MNKPadding", "GNHWC,GKYXC,EmptyTuple,GNHWK", @@ -38,13 +45,17 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} @@ -56,8 +67,10 @@ TEST(FwdConvInstances, .with_dl_transfer(DlFwdTransfer); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", - "256,128,128,16", + expected_transfer_parameters, "Filter1x1Pad0", "MNKPadding", "GNHWC,GKYXC,EmptyTuple,GNHWK", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 7b522403d3..bba0128810 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { 
@@ -11,13 +12,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -29,8 +34,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v3_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,256,256,32", + expected_transfer_parameters, "Filter1x1Pad0", "Intrawave", "v3", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp index 615d098c7c..79ee4915e8 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -11,13 +12,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) { + using enum ck_tile::builder::ConvDirection; + 
using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP32, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCHW}}, - .weight = {.config = {.layout = TensorLayout::GKCYX}}, - .output = {.config = {.layout = TensorLayout::NGKHW}}}; + .direction = FORWARD, + .data_type = FP32, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKCYX}}, + .output = {.config = {.layout = NGKHW}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -29,8 +34,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,128,128,32", + expected_transfer_parameters, "Filter1x1Stride1Pad0", "Intrawave", "v4", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp index 4dd9e2beef..3e3d7e8c2b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,13 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP8, - 
.accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + .direction = FORWARD, + .data_type = FP8, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} @@ -29,8 +34,10 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", - "256,256,128,32", + expected_transfer_parameters, "Default", "NHWGC,GKYXC,EmptyTuple,NHWGK", "PassThrough,PassThrough,PassThrough", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 8fe58dbe82..3019c57a18 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -11,13 +12,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = 
{.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ @@ -30,8 +35,10 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor", - "256,256,128,32", + expected_transfer_parameters, "Default", "GNHWC,GKYXC,EmptyTuple,GNHWK", "PassThrough,PassThrough,PassThrough", @@ -42,13 +49,17 @@ TEST( FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ @@ -61,8 +72,10 @@ TEST( .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); 
run_test({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor", - "128,128,128,32", + expected_transfer_parameters, "Filter1x1Pad0", "GNHWC,GKYXC,EmptyTuple,GNHWK", "PassThrough,PassThrough,PassThrough", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp index 2df76ab3e0..3f9bdfb972 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,14 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .data_type = DataType::BF16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNDHWC}}, - .weight = {.config = {.layout = TensorLayout::GKZYXC}}, - .output = {.config = {.layout = TensorLayout::GNDHWK}}}; + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = FORWARD, + .data_type = BF16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNDHWC}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = GNDHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -30,8 +34,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v3_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,256,256,32", + 
expected_transfer_parameters, "Default", "Intrawave", "v3", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index ad626d9a15..11c8172533 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,14 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NDHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKZYXC}}, - .output = {.config = {.layout = TensorLayout::NDHWGK}}}; + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NDHWGC}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = NDHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -31,8 +35,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,128,128,32", + expected_transfer_parameters, "Filter1x1Pad0", "Intrawave", "v4", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp 
b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index 85974ace5d..33c01c8ac4 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,14 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP32, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCDHW}}, - .weight = {.config = {.layout = TensorLayout::GKCZYX}}, - .output = {.config = {.layout = TensorLayout::NGKDHW}}}; + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = FORWARD, + .data_type = FP32, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCDHW}}, + .weight = {.config = {.layout = GKCZYX}}, + .output = {.config = {.layout = NGKDHW}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -31,8 +35,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,256,256,32", + expected_transfer_parameters, "Filter1x1Pad0", "Intrawave", "v1", diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index a6a7694703..d052aba548 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp 
+++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -12,6 +12,12 @@ namespace { +using ck_tile::builder::ConvDirection; +using ck_tile::builder::DataType; +using ck_tile::builder::ElementwiseOperation; +using ck_tile::builder::PipelineScheduler; +using ck_tile::builder::PipelineVersion; +using ck_tile::builder::TensorLayout; using ::testing::ElementsAre; // Test fixture for ConvTraits tests @@ -84,15 +90,13 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) // Verify signature information EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); EXPECT_THAT(Traits::layout, - ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC, - ck_tile::builder::TensorLayout::GKYXC, - ck_tile::builder::TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(Traits::data_type, DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); @@ -145,8 +149,8 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); // Verify pipeline configuration - EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE); - EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1); + 
EXPECT_EQ(Traits::pipeline_scheduler, PipelineScheduler::INTRAWAVE); + EXPECT_EQ(Traits::pipeline_version, PipelineVersion::V1); } // Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle @@ -214,15 +218,13 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) // Verify signature information EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); EXPECT_THAT(Traits::layout, - ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC, - ck_tile::builder::TensorLayout::GKYXC, - ck_tile::builder::TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(Traits::data_type, DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); @@ -300,15 +302,13 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) // Verify signature information EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); EXPECT_THAT(Traits::layout, - ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC, - ck_tile::builder::TensorLayout::GKYXC, - ck_tile::builder::TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); - 
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(Traits::data_type, DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); diff --git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index ef87981c3d..f046289057 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -14,8 +14,8 @@ struct TensorConfig { TensorLayout layout; // Optional data types, override the type defined in the signature if provided. 
- DataType data_type{DataType::UNDEFINDED}; - DataType compute_type{DataType::UNDEFINDED}; + DataType data_type{DataType::UNDEFINED_DATA_TYPE}; + DataType compute_type{DataType::UNDEFINED_DATA_TYPE}; }; template diff --git a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_convscale.cpp b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_convscale.cpp index 3b3b0fa7e1..60af599551 100644 --- a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_convscale.cpp +++ b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_convscale.cpp @@ -76,54 +76,54 @@ struct F8_ConvScale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", 
- "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; @@ -141,54 +141,54 @@ struct F8_BF8_comb1_ConvScale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>" // clang-format on }; }; @@ -206,54 +206,54 @@ struct F8_BF8_comb2_ConvScale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>" // clang-format on }; }; @@ -271,54 +271,54 @@ struct F8_BF8_comb3_ConvScale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>" // clang-format on }; }; @@ -336,54 +336,54 @@ struct F8_float_CombConvScale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; @@ -401,54 +401,54 @@ struct F8_ConvScaleRelu constexpr static auto expected = { // clang-format off - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", 
- "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", 
+ "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; @@ -466,54 +466,54 @@ struct F8_CombConvScaleRelu constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; @@ -531,54 +531,54 @@ struct F8_ConvScaleAdd constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; @@ -596,54 +596,54 @@ struct F8_ConvInvscale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; diff --git a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp index 2e06ebc74c..6aa2f57db2 100644 --- 
a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp +++ b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp @@ -85,9 +85,9 @@ struct DyOp_F32_2 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" // clang-format on }; }; @@ -98,9 +98,9 @@ struct DyOp_F32_3 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" // clang-format on }; }; @@ -111,9 +111,9 @@ struct DyOp_F16_2 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" // clang-format on }; }; @@ -124,9 +124,9 @@ struct DyOp_F16_3 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" // clang-format on }; }; @@ -137,9 +137,9 @@ struct DyOp_BF16_2 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" // clang-format on }; }; @@ -150,9 +150,9 @@ struct DyOp_BF16_3 constexpr static auto expected = { // clang-format off - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" // clang-format on }; }; @@ -163,9 +163,9 @@ struct DyOp_INT8_2 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" // clang-format on }; }; @@ -176,9 +176,9 @@ struct DyOp_INT8_3 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" // clang-format on }; }; diff --git a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp index 56843b214f..918642c266 100644 --- a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp +++ b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp @@ -53,18 +53,18 @@ struct F32 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" // clang-format on }; }; @@ -75,18 +75,18 @@ struct F16 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" // clang-format on }; }; @@ -97,18 +97,18 @@ struct BF16 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" // clang-format on }; }; @@ -119,18 +119,18 @@ struct S8 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" // clang-format on }; }; diff --git a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp index a833a1fe87..74f5f5e231 
100644 --- a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp +++ b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp @@ -54,18 +54,18 @@ struct F32 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" // clang-format on }; }; @@ -76,18 +76,18 @@ struct F16 constexpr static auto expected = { // clang-format off - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" // clang-format on }; }; @@ -98,18 +98,18 @@ struct BF16 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" // clang-format on }; }; @@ -120,18 +120,18 @@ struct S8 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" // clang-format on }; }; diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 689577fb3b..ace9ce0239 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -30,8 +30,8 @@ static_assert(!ckb::TensorOperatorDescriptor); struct TensorConfig { ckb::TensorLayout layout; - ckb::DataType data_type{ckb::DataType::UNDEFINDED}; - ckb::DataType compute_type{ckb::DataType::UNDEFINDED}; + ckb::DataType data_type{ckb::DataType::UNDEFINED_DATA_TYPE}; + ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE}; }; struct ConvTensorSimple @@ -55,39 +55,49 @@ struct ConvTensorWithInvalidOp // This includes dimensionality, direction, data layout, and data type. 
struct ConvSignature { + using enum ckb::DataType; + using enum ckb::TensorLayout; + int spatial_dim = 2; - ckb::DataType data_type = ckb::DataType::FP16; - ckb::DataType accumulation_data_type = ckb::DataType::FP32; - ConvTensorSimple input = {.config = {ckb::TensorLayout::GNHWC}}; - ConvTensorSimple weight = {.config = {ckb::TensorLayout::GKYXC}}; - ConvTensorSimple output = {.config = {ckb::TensorLayout::GNHWK}}; + ckb::DataType data_type = FP16; + ckb::DataType accumulation_data_type = FP32; + ConvTensorSimple input = {.config = {GNHWC}}; + ConvTensorSimple weight = {.config = {GKYXC}}; + ConvTensorSimple output = {.config = {GNHWK}}; }; static_assert(ckb::ConvSignatureDescriptor); // Compile time tests for concepts struct ConvSignatureWithOptionalParams { + using enum ckb::DataType; + using enum ckb::TensorLayout; + using enum ckb::ConvDirection; + using enum ckb::ElementwiseOperation; + int spatial_dim = 2; - ckb::DataType data_type = ckb::DataType::FP16; - ckb::DataType accumulation_data_type = ckb::DataType::FP32; - ckb::ConvDirection direction = ckb::ConvDirection::FORWARD; + ckb::DataType data_type = FP16; + ckb::DataType accumulation_data_type = FP32; + ckb::ConvDirection direction = FORWARD; ConvTensorWithOp input = { - .config = {ckb::TensorLayout::GNHWC, ckb::DataType::FP16}, + .config = {GNHWC, FP16}, }; - ConvTensorWithOp weight = {.config = {ckb::TensorLayout::GKYXC, ckb::DataType::FP16}}; - ConvTensorWithOp output = {.config = {ckb::TensorLayout::GNHWK, ckb::DataType::FP16}, - .operation = {ckb::ElementwiseOperation::SCALE}}; + ConvTensorWithOp weight = {.config = {GKYXC, FP16}}; + ConvTensorWithOp output = {.config = {GNHWK, FP16}, .operation = {SCALE}}; }; static_assert(ckb::ConvSignatureDescriptor); struct ConvSignatureWithInvalidOptionalParams { + using enum ckb::DataType; + using enum ckb::TensorLayout; + int spatial_dim = 2; - ckb::DataType data_type = ckb::DataType::FP16; - ckb::DataType accumulation_data_type = ckb::DataType::FP32; - 
ConvTensorWithInvalidOp input = {.config = {ckb::TensorLayout::GNHWC}}; - ConvTensorWithInvalidOp weight = {.config = {ckb::TensorLayout::GKYXC}}; - ConvTensorWithInvalidOp output = {.config = {ckb::TensorLayout::GNHWK}}; + ckb::DataType data_type = FP16; + ckb::DataType accumulation_data_type = FP32; + ConvTensorWithInvalidOp input = {.config = {GNHWC}}; + ConvTensorWithInvalidOp weight = {.config = {GKYXC}}; + ConvTensorWithInvalidOp output = {.config = {GNHWK}}; }; static_assert(!ckb::ConvSignatureDescriptor); diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 396533cef4..6dd2a4eada 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -262,14 +262,14 @@ TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat) ",2" // ABlockTransferSrcVectorDim ",8" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",1" // ABlockLdsExtraM + ",true" // ABlockLdsExtraM ",Seq(4,64,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",8" // BBlockTransferSrcScalarPerVector ",8" // BBlockTransferDstScalarPerVector_BK1 - ",1" // BBlockLdsExtraN + ",true" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths @@ -377,14 +377,14 @@ TEST(InstanceTraits, BaseInstanceStringReturnsCorrectFormat) ",2" // ABlockTransferSrcVectorDim ",8" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",1" // ABlockLdsExtraM + ",true" // ABlockLdsExtraM ",Seq(4,64,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",8" // 
BBlockTransferSrcScalarPerVector ",8" // BBlockTransferDstScalarPerVector_BK1 - ",1" // BBlockLdsExtraN + ",true" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths @@ -492,14 +492,14 @@ TEST(InstanceTraits, LargeTensorInstanceStringReturnsCorrectFormat) ",2" // ABlockTransferSrcVectorDim ",8" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",1" // ABlockLdsExtraM + ",true" // ABlockLdsExtraM ",Seq(4,64,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",8" // BBlockTransferSrcScalarPerVector ",8" // BBlockTransferDstScalarPerVector_BK1 - ",1" // BBlockLdsExtraN + ",true" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp index 9929f276a7..35f3db1469 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp @@ -60,14 +60,14 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" ",2" // ABlockTransferSrcVectorDim ",1" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",1" // ABlockLdsExtraM + ",true" // ABlockLdsExtraM ",Seq(4,16,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",1" // BBlockTransferSrcScalarPerVector ",8" // BBlockTransferDstScalarPerVector_BK1 - ",1" // BBlockLdsExtraN + ",true" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,16,1,4)" 
// CDEBlockTransferClusterLengths diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp index aecce25f1d..26b50bea6d 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp @@ -60,14 +60,14 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Ten ",2" // ABlockTransferSrcVectorDim ",1" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",1" // ABlockLdsExtraM + ",true" // ABlockLdsExtraM ",Seq(4,16,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",1" // BBlockTransferSrcScalarPerVector ",8" // BBlockTransferDstScalarPerVector_BK1 - ",1" // BBlockLdsExtraN + ",true" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp index 7eeaec8e25..604667dd10 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp @@ -60,14 +60,14 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" ",2" // ABlockTransferSrcVectorDim ",8" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",0" // ABlockLdsExtraM + ",false" // ABlockLdsExtraM ",Seq(8,32,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",8" // BBlockTransferSrcScalarPerVector ",8" // 
BBlockTransferDstScalarPerVector_BK1 - ",0" // BBlockLdsExtraN + ",false" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths diff --git a/experimental/builder/test/test_testing_utils.cpp b/experimental/builder/test/test_testing_utils.cpp index 694bec4c20..dd65f3f327 100644 --- a/experimental/builder/test/test_testing_utils.cpp +++ b/experimental/builder/test/test_testing_utils.cpp @@ -34,8 +34,8 @@ TEST(InstanceSet, FromFactory) const auto* el = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16," "fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,Default,MNKPadding,1,128," - "128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2)," - "Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp16,fp16,Default,1>"; + "128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1)," + "Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp16,fp16,Default,1>"; EXPECT_THAT(instances.instances, testing::Contains(el)); } diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index 26df33cc8d..ce31f41933 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -9,27 +9,34 @@ namespace { -namespace ckb = ::ck_tile::builder; -using ::ck_tile::builder::DataType; -using ::ck_tile::builder::ElementwiseOperation; -using ::ck_tile::builder::TensorLayout; -using ::ck_tile::builder::factory::internal::AuxiliaryTensorLayouts; -using ::ck_tile::builder::factory::internal::ConvTensorLayouts; -using ::ck_tile::builder::factory::internal::LayoutToCK; +namespace ckb = ck_tile::builder; +using ck_tile::builder::DataType; +using ck_tile::builder::ElementwiseOperation; +using ck_tile::builder::TensorLayout; +using ck_tile::builder::factory::internal::AuxiliaryTensorLayouts; 
+using ck_tile::builder::factory::internal::ConvTensorLayouts; +using ck_tile::builder::factory::internal::LayoutToCK; +using ck_tile::builder::test::ConvolutionTensor; +using ck_tile::builder::test::ConvSignature; +using ck_tile::builder::test::TensorConfig; +using ck_tile::builder::test::TensorOperation; -using namespace ::ck_tile::builder::test; -using enum ::ck_tile::builder::ConvDirection; +namespace enums { +using enum ck_tile::builder::ConvDirection; +using enum ck_tile::builder::TensorLayout; +using enum ck_tile::builder::DataType; +} // namespace enums TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 1, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NWGC}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::NWGK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NWGC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NWGK}}}; using TensorLayouts = ConvTensorLayouts; @@ -41,14 +48,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 1, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCW}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::NGKW}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config 
= {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; using TensorLayouts = ConvTensorLayouts; @@ -60,14 +67,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 1, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNWC}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::GNWK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNWC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = GNWK}}}; using TensorLayouts = ConvTensorLayouts; @@ -79,14 +86,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 1, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCW}}, - .weight = {.config = {.layout = TensorLayout::GKCX}}, - .output = {.config = {.layout = TensorLayout::NGKW}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKCX}}, + .output = {.config = {.layout = NGKW}}}; using TensorLayouts = ConvTensorLayouts; @@ -98,14 +105,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 2, - .direction = FORWARD, - .data_type = DataType::FP16, - 
.accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCHW}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NGKHW}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NGKHW}}}; using TensorLayouts = ConvTensorLayouts; @@ -117,14 +124,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 2, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; using TensorLayouts = ConvTensorLayouts; @@ -136,14 +143,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 2, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + 
.data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; using TensorLayouts = ConvTensorLayouts; @@ -155,14 +162,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 2, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCHW}}, - .weight = {.config = {.layout = TensorLayout::GKCYX}}, - .output = {.config = {.layout = TensorLayout::NGKHW}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKCYX}}, + .output = {.config = {.layout = NGKHW}}}; using TensorLayouts = ConvTensorLayouts; @@ -174,14 +181,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 3, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCDHW}}, - .weight = {.config = {.layout = TensorLayout::GKCZYX}}, - .output = {.config = {.layout = TensorLayout::NGKDHW}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 3, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = NGCDHW}}, + .weight = {.config = {.layout = GKCZYX}}, + .output = {.config = {.layout = NGKDHW}}}; using TensorLayouts = ConvTensorLayouts; @@ -193,14 +200,14 @@ TEST(ConvTensorLayout, 
AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 3, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NDHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKZYXC}}, - .output = {.config = {.layout = TensorLayout::NDHWGK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 3, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = NDHWGC}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = NDHWGK}}}; using TensorLayouts = ConvTensorLayouts; @@ -212,14 +219,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 3, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNDHWC}}, - .weight = {.config = {.layout = TensorLayout::GKZYXC}}, - .output = {.config = {.layout = TensorLayout::GNDHWK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 3, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = GNDHWC}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = GNDHWK}}}; using TensorLayouts = ConvTensorLayouts; @@ -261,8 +268,10 @@ struct MockAuxiliaryTensorConfig TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout) { + using namespace enums; + static constexpr std::array aux_configs = { - MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}}; + MockAuxiliaryTensorConfig{.layout = G_K_strided}}; using AuxLayouts = 
AuxiliaryTensorLayouts; @@ -273,6 +282,8 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout) TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout) { + using namespace enums; + static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; @@ -285,8 +296,10 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout) TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout) { + using namespace enums; + static constexpr std::array aux_configs = { - MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}}; + MockAuxiliaryTensorConfig{.layout = G_C_strided}}; using AuxLayouts = AuxiliaryTensorLayouts; @@ -297,9 +310,11 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout) TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors) { + using namespace enums; + static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}, - MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; + MockAuxiliaryTensorConfig{.layout = GC}}; using AuxLayouts = AuxiliaryTensorLayouts; @@ -311,10 +326,12 @@ TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors) TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors) { + using namespace enums; + static constexpr std::array aux_configs = { - MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}, - MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}, - MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}}; + MockAuxiliaryTensorConfig{.layout = G_K_strided}, + MockAuxiliaryTensorConfig{.layout = GC}, + MockAuxiliaryTensorConfig{.layout = G_C_strided}}; using AuxLayouts = AuxiliaryTensorLayouts; @@ -327,8 +344,10 @@ TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors) TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution) { + using namespace enums; + static constexpr std::array aux_configs = { - MockAuxiliaryTensorConfig{.layout 
= TensorLayout::G_K_strided}}; + MockAuxiliaryTensorConfig{.layout = G_K_strided}}; using AuxLayouts = AuxiliaryTensorLayouts; @@ -339,8 +358,10 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution) TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution) { + using namespace enums; + static constexpr std::array aux_configs = { - MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; + MockAuxiliaryTensorConfig{.layout = GC}}; using AuxLayouts = AuxiliaryTensorLayouts; @@ -351,7 +372,8 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution) TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) { - using OutputOp = TensorOperation; + using namespace enums; + using OutputOp = TensorOperation; static constexpr auto sig = ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ @@ -359,9 +381,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) .direction = FORWARD, .data_type = DataType::FP16, .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCHW}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NGKHW}, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NGKHW}, .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; @@ -377,7 +399,8 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) { - using OutputOp = TensorOperation; + using namespace enums; + using OutputOp = TensorOperation; static constexpr auto sig = ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ @@ -385,9 +408,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) .direction = FORWARD, .data_type = DataType::BF16, .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = 
{.config = {.layout = TensorLayout::NHWGK}, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}, .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; @@ -403,8 +426,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) { - using OutputOp = TensorOperation; + using namespace enums; + using OutputOp = + TensorOperation; static constexpr auto sig = ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ @@ -412,9 +436,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) .direction = FORWARD, .data_type = DataType::FP16, .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}, .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALEADD_SCALEADD_RELU}}}; @@ -431,7 +455,8 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) { - using OutputOp = TensorOperation; + using namespace enums; + using OutputOp = TensorOperation; static constexpr auto sig = ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ @@ -439,9 +464,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) .direction = FORWARD, .data_type = DataType::FP32, .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NWGC}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::NWGK}, + .input = {.config = {.layout = NWGC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NWGK}, .operation = OutputOp{.elementwise_operation = 
ElementwiseOperation::SCALE}}}; @@ -457,7 +482,8 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias) { - using OutputOp = TensorOperation; + using namespace enums; + using OutputOp = TensorOperation; static constexpr auto sig = ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ @@ -465,9 +491,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias) .direction = FORWARD, .data_type = DataType::FP16, .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NDHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKZYXC}}, - .output = {.config = {.layout = TensorLayout::NDHWGK}, + .input = {.config = {.layout = NDHWGC}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = NDHWGK}, .operation = OutputOp{.elementwise_operation = ElementwiseOperation::BIAS_BNORM_CLAMP}}}; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp new file mode 100644 index 0000000000..e4db149a98 --- /dev/null +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -0,0 +1,346 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "../impl/conv_algorithm_types.hpp" +#include +#include + +namespace ck_tile::builder::test { + +namespace ckb = ck_tile::builder; + +// Helper function to convert arrays to Seq(...) 
format +template +std::string array_to_seq(const std::array& arr) +{ + std::ostringstream oss; + oss << "Seq("; + for(size_t i = 0; i < N; ++i) + { + if(i > 0) + oss << ","; + oss << arr[i]; + } + oss << ")"; + return oss.str(); +} + +// Base template - will cause compilation error for unsupported types +template +std::string to_string(T) +{ + static_assert(sizeof(T) == 0, "Unsupported type"); + return ""; +} + +// Template specializations for enum types + +template <> +inline std::string to_string(PipelineVersion t) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +template <> +inline std::string to_string(PipelineScheduler t) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +template <> +inline std::string to_string(ConvFwdSpecialization t) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +template <> +inline std::string to_string(GemmSpecialization t) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +// Template specializations for struct types + +template <> +inline std::string to_string>(MNK t) +{ + return array_to_seq(std::array{t.m, t.n, t.k}); +} + +template <> +inline std::string to_string(ThreadBlock t) +{ + std::ostringstream oss; + oss << t.block_size << "," << t.tile_size.m << "," << t.tile_size.n << "," << t.tile_size.k; + return oss.str(); +} + +template <> +inline std::string to_string(GridwiseXdlGemm t) +{ + std::ostringstream oss; + oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << "," + << t.m_xdl_per_wave << "," << t.n_xdl_per_wave; + return oss.str(); +} + +template <> +inline std::string to_string(GridwiseWmmaGemm t) +{ + std::ostringstream oss; + oss << t.k1 << "," << t.m_per_wmma << "," << t.n_per_wmma << "," << t.m_wmma_per_wave << "," + << t.n_wmma_per_wave; + return oss.str(); +} + +template <> +inline std::string to_string(BlockGemm t) +{ + std::ostringstream oss; + oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version); + return 
oss.str(); +} + +template <> +inline std::string to_string(BlockTransfer t) +{ + return array_to_seq(std::array{t.k0, t.m_n, t.k1}); +} + +template <> +inline std::string to_string(ThreadCluster t) +{ + return array_to_seq( + std::array{t.m_block, t.m_wave_per_xdl, t.n_block, t.n_wave_per_xdl}); +} + +template <> +inline std::string to_string(LdsTransfer t) +{ + std::ostringstream oss; + oss << t.src_vector_dim << "," << t.src_scalar_per_vector << "," << t.lds_dst_scalar_per_vector + << "," << (t.lds_padding ? "true" : "false") << "," + << (t.is_direct_load ? "true" : "false"); + return oss.str(); +} + +template <> +inline std::string to_string(AccessOrder t) +{ + return array_to_seq(t.order); +} + +template <> +inline std::string to_string(TransferAB t) +{ + std::ostringstream oss; + oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," + << to_string(t.src_access_order) << "," << t.lds_transfer.src_vector_dim << "," + << t.lds_transfer.src_scalar_per_vector << "," << t.lds_transfer.lds_dst_scalar_per_vector + << "," << (t.lds_transfer.lds_padding ? 
"true" : "false"); + return oss.str(); +} + +template <> +inline std::string to_string(TransferC t) +{ + std::ostringstream oss; + oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << "," + << to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector; + return oss.str(); +} + +template <> +inline std::string to_string(TransferABC t) +{ + std::ostringstream oss; + oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); + return oss.str(); +} + +template <> +inline std::string to_string(DlThreadConfig t) +{ + std::ostringstream oss; + oss << t.k1 << "," << t.m1_per_thread << "," << t.n1_per_thread << "," << t.k_per_thread; + return oss.str(); +} + +template <> +inline std::string to_string(DlThreadCluster t) +{ + std::ostringstream oss; + oss << array_to_seq(t.m1_xs) << "," << array_to_seq(t.n1_xs); + return oss.str(); +} + +template <> +inline std::string to_string(DlBlockTransfer t) +{ + std::ostringstream oss; + oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths) + << "," << array_to_seq(t.thread_cluster_arrange_order) << "," + << array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths) + << "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << "," + << array_to_seq(t.dst_vector_tensor_lengths); + return oss.str(); +} + +template <> +inline std::string to_string(DlEpilogue t) +{ + std::ostringstream oss; + oss << array_to_seq(t.src_dst_access_order) << "," << t.src_dst_vector_dim << "," + << t.dst_scalar_per_vector; + return oss.str(); +} + +template <> +inline std::string to_string(DlBlockTransferAB t) +{ + return to_string(t.block_transfer); +} + +template <> +inline std::string to_string(DlBlockTransferC t) +{ + return to_string(t.epilogue); +} + +template <> +inline std::string to_string(DlTransferABC t) +{ + std::ostringstream oss; + oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); + 
return oss.str(); +} + +// Template specializations for factory wrapper types + +template <> +inline std::string to_string(ThreadBlock_ t) +{ + return to_string(t.thread_block); +} + +template <> +inline std::string to_string(XdlGemm_ t) +{ + return to_string(t.gridwise_gemm); +} + +template <> +inline std::string to_string(WmmaGemm_ t) +{ + return to_string(t.gridwise_gemm); +} + +template <> +inline std::string to_string(Transfer_ t) +{ + return to_string(t.transfer); +} + +template <> +inline std::string to_string(ConvSpecialization_ t) +{ + std::ostringstream oss; + oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization); + return oss.str(); +} + +template <> +inline std::string to_string(Prefetch_ t) +{ + std::ostringstream oss; + oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << "," + << to_string(t.loop_scheduler); + return oss.str(); +} + +template <> +inline std::string to_string(BlockGemm_ t) +{ + return to_string(t.block_gemm); +} + +template <> +inline std::string to_string(DlThreadConfig_ t) +{ + return to_string(t.thread_config); +} + +template <> +inline std::string to_string(DlThreadCluster_ t) +{ + return to_string(t.thread_cluster); +} + +template <> +inline std::string to_string(DlTransfer_ t) +{ + return to_string(t.transfer); +} + +// Template specializations for algorithm types + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + 
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t) +{ + return to_string(t.base_algorithm); +} + +} // namespace ck_tile::builder::test From ce99cab6056d1ffef5acb6f4ad7ede87a46a3cfc Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Thu, 11 Dec 2025 09:06:20 +0100 Subject: [PATCH 40/65] Wmma support for gemm_ab_scale (#3314) * Support gemm_ab_scale: - Add tests - Integrate scaling implementation in multiple D - Generalize existing b_scale for ab_scale - Add instances - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK - Add support for all layouts supported by xdl - Fix splitk xdl * Fix copyright * Wmma support for gemm_blockscale_wp (#3315) * Support for preshuffle with ab scale - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale - add support for AScaleLayout amnd BScaleLayout (can be different from ALayout and BLayout, respectively) - add Run method in v1 pipeline to support preshuffle + scaling - add support for preshuffle gemms in common invoker - Add splitk support * Fix copyright header --- .../65_gemm_multiply_multiply/CMakeLists.txt | 2 + ...mm_multiply_multiply_wmma_fp8_ab_scale.cpp | 345 +++++++++++++ ...ltiply_wmma_fp8_blockscale_bpreshuffle.cpp | 357 +++++++++++++ .../blockwise_gemm_pipeline_wmmaops_base.hpp | 146 ++++-- .../blockwise_gemm_pipeline_wmmaops_v1.hpp | 468 
+++++++++++++++++- .../blockwise_gemm_pipeline_wmmaops_v3.hpp | 345 ++++++++++++- .../device_gemm_multiple_d_ab_scale.hpp | 347 +++++++++++++ ..._batched_gemm_wmma_cshuffle_v3_b_scale.hpp | 11 +- ...m_multiple_d_wmma_cshuffle_v3_ab_scale.hpp | 362 ++++++++++++++ ...ltiple_d_wmma_cshuffle_v3_b_preshuffle.hpp | 308 +----------- ...mma_cshuffle_v3_blockscale_bpreshuffle.hpp | 360 ++++++++++++++ ...mm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp | 102 +++- .../device_gemm_wmma_cshuffle_v3_b_scale.hpp | 10 +- .../device_gemm_wmma_cshuffle_v3_common.hpp | 200 ++++++-- .../gridwise_ab_transfer_thread_tiles.hpp | 10 +- ...ise_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 6 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 7 +- ...idwise_gemm_wmma_cshuffle_v3_ab_scale.hpp} | 393 +++++++++++---- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 47 +- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 74 ++- .../gpu/gemm_ab_scale.hpp | 394 ++++++++++++++- .../gpu/gemm_blockscale_wp.hpp | 147 ++++++ .../gpu/CMakeLists.txt | 12 +- .../gpu/gemm_ab_scale/CMakeLists.txt | 21 +- ...e_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp | 79 +++ ...n_mn_128_128_128_comp_default_instance.cpp | 37 ++ ..._mn_128_128_128_comp_kpadding_instance.cpp | 37 ++ ...mn_128_128_128_mem_v1_default_instance.cpp | 38 ++ ...n_128_128_128_mem_v1_kpadding_instance.cpp | 38 ++ ...e_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp | 80 +++ ...n_mn_128_128_128_comp_default_instance.cpp | 37 ++ ..._mn_128_128_128_comp_kpadding_instance.cpp | 37 ++ ...mn_128_128_128_mem_v1_default_instance.cpp | 38 ++ ...n_128_128_128_mem_v1_kpadding_instance.cpp | 38 ++ ...e_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp | 95 ++++ ...k_mn_128_128_128_comp_default_instance.cpp | 37 ++ ..._mn_128_128_128_comp_kpadding_instance.cpp | 37 ++ ...mn_128_128_128_mem_v1_default_instance.cpp | 38 ++ ...n_128_128_128_mem_v1_kpadding_instance.cpp | 38 ++ .../gpu/gemm_blockscale_wp/CMakeLists.txt | 5 +- ...p_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp | 77 +++ 
...k_mn_128_128_128_comp_default_instance.cpp | 38 ++ ...nk_mn_128_128_128_mem_default_instance.cpp | 38 ++ .../profiler/profile_gemm_ab_scale_impl.hpp | 6 +- .../profile_gemm_blockscale_wp_impl.hpp | 2 +- test/CMakeLists.txt | 1 + test/gemm_ab_scale/CMakeLists.txt | 9 + test/gemm_ab_scale/test_gemm_ab_scale.cpp | 236 +++++++++ .../gemm_ab_scale/test_gemm_ab_scale_util.hpp | 102 ++++ test/gemm_blockscale_wp/CMakeLists.txt | 4 +- ...p8.cpp => test_gemm_blockscale_wp_fp8.cpp} | 0 51 files changed, 5144 insertions(+), 552 deletions(-) create mode 100644 example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp create mode 100644 example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp rename include/ck/tensor_operation/gpu/grid/{gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp => gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp} (58%) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp create mode 100644 test/gemm_ab_scale/CMakeLists.txt create mode 100644 test/gemm_ab_scale/test_gemm_ab_scale.cpp create mode 100644 test/gemm_ab_scale/test_gemm_ab_scale_util.hpp rename test/gemm_blockscale_wp/{test_gemm_blockscale_wp_xdl_fp8.cpp => test_gemm_blockscale_wp_fp8.cpp} (100%) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index abfbe115fb..944a8f96bf 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -77,3 +77,5 @@ example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCAL add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_wmma_fp16_bpreshuffle gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp) add_example_executable(example_gemm_multiply_multiply_wmma_fp8_bpreshuffle gemm_multiply_multiply_wmma_fp8_bpreshuffle.cpp) +add_example_executable(example_gemm_multiply_multiply_wmma_fp8_ab_scale gemm_multiply_multiply_wmma_fp8_ab_scale.cpp) +add_example_executable(example_gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle 
gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp new file mode 100644 index 0000000000..0fb7a70781 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp @@ -0,0 +1,345 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using B0Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = 
ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3 + // clang-format off + , S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 1, S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + bool flush_cache = true; + + // GEMM shape + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 8 || argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + flush_cache = std::stoi(argv[7]); + + if(argc == 9) + { + KBatch = std::stoi(argv[8]); + } + + StrideA = K; + StrideB = K; + StrideE = N; + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + printf("arg8: KBatch (default: 1)\n"); + exit(0); + } + + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + auto f_host_tensor_descriptor = + 
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + ck::Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + ck::Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, + (K + Scale_Block_K - 1) / Scale_Block_K, + Scale_Stride_AM, + A0Layout{})); + ck::Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + ck::Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + B0Layout{})); + ck::Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + ck::Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + 
a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + ck::DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + std::string op_name = device_op.GetTypeString(); + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(static_cast(a0_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + std::array{}, + static_cast(e_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + static_cast(a1_device_buf.GetDeviceBuffer()), + 
static_cast(b1_device_buf.GetDeviceBuffer()), + a_element_op, + b_element_op, + cde_element_op, + KBatch); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float ave_time = .0; + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0, 50, 100}); + + int pass = 0; + + if(do_verification) + { + ck::Tensor c_m_n({M, N}); + ck::Tensor a_m_k({M, K}); + ck::Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / Scale_Block_M, k / Scale_Block_K); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / Scale_Block_K, n / Scale_Block_N); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) + ? 
0 + : 1; + } + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op_name << ", KBatch " << KBatch << std::endl; + + return pass; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp new file mode 100644 index 0000000000..ba95724d3f --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp @@ -0,0 +1,357 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +#include "common.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using A1Layout = Col; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +static constexpr int KPack = 16; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle + // 
clang-format off + , S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 1, + S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + bool flush_cache = true; + + // GEMM shape + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 8 || argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + flush_cache = std::stoi(argv[7]); + + if(argc == 9) + { + KBatch = std::stoi(argv[8]); + } + + StrideA = K; + StrideB = K; + StrideE = N; + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + printf("arg8: KBatch (default: 1)\n"); + exit(0); + } + + // Transpose the AScale tensor for better performance + ck::index_t Scale_Stride_AK = (M + Scale_Block_M - 1) / Scale_Block_M; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck::HostTensorDescriptor({row, 
col}, {1_uz, stride}); + } + }; + + ck::Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + ck::Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, + (K + Scale_Block_K - 1) / Scale_Block_K, + Scale_Stride_AK, + A1Layout{})); + ck::Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + ck::Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + ck::Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + B0Layout{})); + ck::Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + ck::Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + 
a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + ck::DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + std::string op_name = device_op.GetTypeString(); + int NPerWmma = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerWmma); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a1_device_buf.GetDeviceBuffer(), + 
b1_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + cde_element_op, + KBatch); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float ave_time = 0.0f; + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op_name << ", KBatch " << KBatch << std::endl; + + if(do_verification) + { + ck::Tensor c_m_n({M, N}); + ck::Tensor a_m_k({M, K}); + ck::Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / Scale_Block_M, k / Scale_Block_K); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / Scale_Block_K, n / Scale_Block_N); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } + + 
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index f24a1eb3bc..f831c0f6cf 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -109,65 +109,145 @@ struct BlockwiseGemmWmmaops_pipeline_base } }; - template - struct BScale + typename ThreadDesc> + struct ABScale { - __device__ BScale(GridDesc b_scale_grid_desc_, - ThreadCopy b_scale_thread_copy_, - GridBuffer b_scale_grid_buf_) - : b_scale_thread_copy(b_scale_thread_copy_), - b_scale_grid_desc(b_scale_grid_desc_), - b_scale_grid_buf(b_scale_grid_buf_) {}; + __device__ ABScale(GridDesc scale_grid_desc_, + ThreadCopy scale_thread_copy_, + GridBuffer scale_grid_buf_) + : scale_thread_copy(scale_thread_copy_), + scale_grid_desc(scale_grid_desc_), + scale_grid_buf(scale_grid_buf_) {}; - static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{}); + static constexpr index_t num_scale_k_block = ThreadDesc{}.GetLength(Number<1>{}); static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block; - static constexpr auto b_scale_thread_desc = BScaleThreadDesc{}; + static constexpr index_t num_slice_mn = ScaleSliceSizeMN; + static constexpr index_t num_slice_k = ScaleSliceSizeK; + static constexpr index_t reg_size_per_wmma = RegSizePerWmma; - static constexpr auto b_scale_thread_copy_step = - make_tuple(make_multi_index(NWaves * NPerWmma, 0), - make_multi_index(-NPerBlock, 0), - make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK)); + static constexpr auto scale_thread_desc = ThreadDesc{}; + + static constexpr auto 
scale_thread_copy_step = + make_tuple(make_multi_index(ScaleSliceStrideMN, 0), + make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, 0), + make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, + ScaleSliceSizeK)); template __device__ void GlobalLoad(bool cond) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, Number<0>{}), - b_scale_thread_bufs(Number{})); + static_for<0, ScaleSliceSizeMN / RegSizePerWmma, 1>{}([&](auto m0) { + scale_thread_copy.Run(scale_grid_desc, + scale_grid_buf, + scale_thread_desc, + make_tuple(m0 * Number{}, Number<0>{}), + scale_thread_bufs(Number{})); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step.At(Number<0>{})); + scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc, + scale_thread_copy_step.At(Number<0>{})); }); if(cond) { - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step.At(Number<2>{})); + scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc, + scale_thread_copy_step.At(Number<2>{})); } else { - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step.At(Number<1>{})); + scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc, + scale_thread_copy_step.At(Number<1>{})); } } - ThreadCopy b_scale_thread_copy; - GridDesc b_scale_grid_desc; - GridBuffer b_scale_grid_buf; - StaticallyIndexedArray{}> b_scale_thread_bufs; + ThreadCopy scale_thread_copy; + GridDesc scale_grid_desc; + GridBuffer scale_grid_buf; + StaticallyIndexedArray{}> scale_thread_bufs; + }; + + template + struct CScale + { + __device__ CScale() {} + + static constexpr auto reg_size_per_wmma = + ck::is_same_v && ck::is_same_v + ? 
1 + : wmma_gemm.GetRegSizePerWmma(); + static constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, + Number{}, + Number{})); + using CScaleThreadDesc = decltype(c_scale_thread_desc); + static constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); + static constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + static constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + using ThreadStaticBuffer = decltype(make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize())); + + __device__ void Load(AScaleStruct& a_scale_struct, BScaleStruct& b_scale_struct) + { + using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc); + using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc); + + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_bufs(I0)(Number{}) = + a_scale_struct.scale_thread_bufs(I0)[Number{}] * + b_scale_struct.scale_thread_bufs(I0)[Number{}]; + }); + }); + }); + } + + __device__ void Clear() + { + static_for<0, reg_size_per_wmma, 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + } + + template + __device__ void UpdateCThreadBuf(CThreadBuf& c_thread_buf) + { + static_for<0, reg_size_per_wmma, 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m_index, n_index, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(make_tuple( + k_index, + (m_index * num_scale_m_block / MRepeat) % 
num_scale_m_block + + (Number{}) % + AScaleStruct::reg_size_per_wmma, + (n_index * num_scale_n_block / NRepeat) % num_scale_n_block)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_bufs(I0)[Number{}]); + }); + } + + StaticallyIndexedArray{}> c_scale_thread_bufs; + StaticBufferTupleOfVector + c_thread_buf_per_scale; }; __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 0f62aee0a8..3b12e7feb0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -174,7 +174,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -188,7 +190,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); // Local prefill 1 @@ -217,6 +220,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, @@ -245,7 +249,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}], b_thread_desc_, @@ -366,6 +370,189 @@ struct BlockwiseGemmWmmaops_pipeline_v1 && + !ck::is_same_v, + bool>::type = false> + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& 
b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + AScaleStruct& a_scale_struct, + BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk(); + static constexpr auto NumScaleKBlock = + Number{}; + + auto a_thread_buf = make_static_buffer( + Base::a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + Base::b_thread_desc_.GetElementSpaceSize()); + + using CScaleStruct = typename Base::template CScale; + auto c_scale_struct = CScaleStruct{}; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Scales global load + a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto blockwise_gemm_func = [&]() { + // Local load + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + Base::a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + Base::b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, 
NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + }; + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + + block_sync_lds(); + blockwise_gemm_func(); + + block_sync_lds(); + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 
1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + blockwise_gemm_func(); + } + } + protected: // A[MRepeat, I1, I1, KPack] static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( @@ -528,6 +715,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + struct KLoopParams + { + static constexpr auto KRepeatNoScale = 1; + static constexpr auto NumScaleKBlock = + Number{}; + static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock; + }; + + template <> + struct KLoopParams + { + static constexpr index_t KRepeatNoScale = KRepeatPerCluster; + static constexpr index_t NumScaleKBlock = 1; + static constexpr index_t KRepeatPerNumScaleKBlock = 1; + }; + template + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -557,7 +763,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); // Local prefill 1 @@ -615,7 +822,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}], b_thread_desc_, @@ -704,6 +911,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -996,7 +1206,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1 && + !ck::is_same_v, + bool>::type = false> + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc&, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer&, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + AScaleStruct& a_scale_struct, + 
BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk(); + static constexpr auto NumScaleKBlock = + Number{}; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + + using CScaleStruct = typename Base::template CScale; + auto c_scale_struct = CScaleStruct{}; + + auto gemm_core_func = [&](auto reg_buf) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[reg_buf] + [Number{}, + I0, + I0, + n0, + I0, + k_index, + Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template 
UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + }; + + auto a_local_prefetch_func = [&]() { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); + }); + }); + }; + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_k0_n0_n1_n2_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Scales global load + a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + + __builtin_amdgcn_sched_barrier(0); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + a_local_prefetch_func(); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_k0_n0_n1_n2_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_struct.template 
GlobalLoad<0>( + (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>( + (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0); + + gemm_core_func(wmma_reg_buf); + + block_sync_lds(); + + // loop prefetch copy + a_local_prefetch_func(); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_k0_n0_n1_n2_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + + gemm_core_func(I0); + + block_sync_lds(); + + // tail Local Prefetch A1 + a_local_prefetch_func(); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + __builtin_amdgcn_sched_barrier(0); + + gemm_core_func(I1); + // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + gemm_core_func(I0); + } + } + protected: static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(Number{}, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 08c765dd0a..b8d451363e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -123,6 +123,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3; using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; using Base::A_K1; using 
Base::A_KRow; @@ -322,7 +325,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}], b_thread_desc_, @@ -348,7 +351,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3 + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -362,7 +367,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop_per_scale == 1); // Local prefill 1 @@ -611,6 +617,339 @@ struct BlockwiseGemmWmmaops_pipeline_v3 && + !ck::is_same_v, + bool>::type = false> + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + AScaleStruct& a_scale_struct, + BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + + constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk(); + static constexpr auto NumScaleKBlock = + Number{}; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + using CScaleStruct = typename Base::template CScale; + auto c_scale_struct = CScaleStruct{}; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Scales global load + a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + b_scale_struct.template 
GlobalLoad<0>(num_loop_per_scale == 1); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2, perform when at least 2 loops exist. + if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full) + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + } + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + + auto local_load_func = [&]() { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); + }); + }); + }; + + local_load_func(); + + __builtin_amdgcn_sched_barrier(0); + + // Main body, perform when at least 3 loops exist. 
+ if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale + .GetVectorTypeReference(Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + 
block_sync_lds(); + + local_load_func(); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 2)); + } + + // Pre-tail, perform when at least 2 loops exist. + if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full) + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // No RunRead or MoveSrcSliceWindow here, already finished them all! + a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + + 
c_scale_struct.Load(a_scale_struct, b_scale_struct); + block_sync_lds(); + + local_load_func(); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + } + + // Tail, always perform. + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + } + protected: using Base::a_thread_copy_; using Base::a_thread_desc_; diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp index 52a915de52..23b5178e3d 100644 --- 
a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp @@ -105,6 +105,353 @@ struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator virtual int GetPreShuffleParameters() = 0; }; +template +struct DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t KBatch) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +}; + +/// @brief Wrapper for backward compatibility that allows to use instances of +/// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK in contexts where +// DeviceGemmMultipleD_BlockScale_BPreshuffle is expected. +/// +/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances(). +/// The only difference between API of DeviceGemmMultipleD_BlockScale_BPreshuffle and +// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK is +/// that DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK::MakeArgumentPointer requires +// an additional parameter KBatch which is explicitly passed as 1 by this wrapper. 
+template +struct DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper + : public DeviceGemmMultipleD_BlockScale_BPreshuffle +{ + using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef __HIPCC_RTC__ + + explicit DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper(std::unique_ptr p_op) + : p_op_(std::move(p_op)) + { + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return p_op_->IsSupportedArgument(p_arg); + } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return p_op_->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + p_a_scale, + p_b_scale, + a_element_op, + b_element_op, + cde_element_op, + 1); // KBatch + } + + std::unique_ptr MakeInvokerPointer() override + { + return p_op_->MakeInvokerPointer(); + } + + int GetPreShuffleParameters() override { return p_op_->GetPreShuffleParameters(); } + + std::string GetTypeString() const override { return p_op_->GetTypeString(); } + + private: + std::unique_ptr p_op_; + +#endif // __HIPCC_RTC__ +}; + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGemmMultipleD_ABScaleSplitK : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t KBatch) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0; +}; + +/// @brief Wrapper for backward compatibility that allows to use instances of +/// DeviceGemmMultipleD_ABScaleSplitK in contexts where DeviceGemmMultipleD_ABScale is +/// expected. +/// +/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances(). +/// The only difference between API of DeviceGemmMultipleD_ABScale and +/// DeviceGemmMultipleD_ABScaleSplitK is that +/// DeviceGemmMultipleD_ABScaleSplitK::MakeArgumentPointer requires an additional parameter +/// KBatch which is explicitly passed as 1 by this wrapper. 
+template +struct DeviceGemmMultipleD_ABScaleSplitKWrapper + : public DeviceGemmMultipleD_ABScale +{ + + using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef __HIPCC_RTC__ + + explicit DeviceGemmMultipleD_ABScaleSplitKWrapper(std::unique_ptr p_op) + : p_op_(std::move(p_op)) + { + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return p_op_->IsSupportedArgument(p_arg); + } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return p_op_->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + p_a_scale, + p_b_scale, + a_element_op, + b_element_op, + cde_element_op, + 1); // KBatch + } + + void SetKBatch(BaseArgument* arg, int KBatch) const override { p_op_->SetKBatch(arg, KBatch); } + + std::unique_ptr MakeInvokerPointer() override + { + return p_op_->MakeInvokerPointer(); + } + + std::string GetTypeString() const override { return p_op_->GetTypeString(); } + + private: + std::unique_ptr p_op_; + +#endif // __HIPCC_RTC__ +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp index 7752b334ed..ee1ddc494d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" @@ -93,7 +93,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) p_bs_grid_shift, karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + karg.p_a_scale_grid, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_b_k_split_offset, p_shared, karg, karg.a_element_op, @@ -315,12 +316,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale }; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale< + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< ALayout, BLayout, Tuple<>, // DsLayout CLayout, Tuple, + void, // AScaleType Tuple, BScaleDataType, AccDataType, @@ -332,6 +334,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale CElementwiseOperation, GemmSpec, BlockSize, + 0, // ScaleBlockM ScaleBlockN, ScaleBlockK, MPerBlock, @@ -405,7 +408,9 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale std::array{StrideB_}, std::array{}, // StrideDs_ StrideC_, + 0, // StrideScaleA StrideScaleB_, + nullptr, p_b_scale_grid_, k_batch_, a_element_op_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp new file mode 100644 index 0000000000..81a5d35e7c --- 
/dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp @@ -0,0 +1,362 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3 + : public DeviceGemmMultipleD_ABScaleSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< + ALayout, + BLayout, + DsLayout, + CLayout, + Tuple, + AScaleDataType, + Tuple, + BScaleDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + 
BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + using DeviceGemmCommon = + DeviceGemm_Wmma_CShuffleV3_Common, + Tuple, + DsDataType, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + // with splitk the implementation doesn't work + // when KRead % ScaleBlockK != 0, independently of K padding + if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0) + { + return false; + } + + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + void SetKBatch(BaseArgument* base_arg, int KBatch) const override + { + auto& arg = *dynamic_cast(base_arg); + arg.KBatch = KBatch; + arg.KRead = GridwiseGemm::CalculateKRead(arg.K, KBatch); + arg.KPadded = GridwiseGemm::CalculateKPadded(arg.K, KBatch); + arg.AK0 = GridwiseGemm::CalculateAK0Padded(arg.K, KBatch); + arg.BK0 = GridwiseGemm::CalculateBK0Padded(arg.K, KBatch); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + std::array p_ds, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const std::array StrideDs, + index_t StrideC, + const BScaleDataType* 
p_a_scale, + const BScaleDataType* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation cde_element_op, + index_t KBatch = 1) + { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return Argument{std::array{p_a}, + std::array{p_b}, + p_ds, + p_c, + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + p_a_scale, + p_b_scale, + KBatch, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch = 1) override + { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? 
math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return std::make_unique(std::array{p_a}, + std::array{p_b}, + p_ds, + static_cast(p_c), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_ABScale_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); 
- __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack); - const index_t k_id = blockIdx.z * num_k_per_block; - - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run( - p_shared, splitk_batch_offset, karg, epilogue_args, k_id); - -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; -#endif -} - -} // namespace ck - namespace ck { namespace tensor_operation { namespace device { @@ -202,270 +156,14 @@ struct DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, - ComputeTypeB>; + ComputeTypeB, + true>; // IsBPreshuffle // Invoker - struct Invoker : public BaseInvoker - { - /// @brief This function issues GPU kernel execution. - /// @param arg The GPU kernel arguments. - /// @param stream_config The HIP stream configuration helper structure. - /// @return The kernel's average execution time (if time measurement is - /// enabled). - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); - - std::array size_as_buffers; - size_as_buffers[Number<0>{}] = - a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize; - - std::array size_bs_buffers; - size_bs_buffers[Number<0>{}] = - b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize; - - const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( - arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); - - std::array size_ds_buffers; - static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - size_ds_buffers[i] = - ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); - }); - - ck::utility::RotatingMemWrapperMultiABD, - Tuple, - DsDataType> - rotating_mem(arg_, - stream_config.rotating_count, - size_as_buffers, - size_bs_buffers, - size_ds_buffers); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid, - 0, - arg_.M * arg_.N * sizeof(EDataType), - stream_config.stream_id_)); - 
}; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_); - } - else - { - if(arg.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid, - 0, - arg.M * arg.N * sizeof(EDataType), - stream_config.stream_id_)); - - ave_time = launch_and_time_kernel( - stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; - } - else - { - return 1; - } - }(); - - // ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is - // currently implemented in such a way that all SrcScalarPerVectors must be the same, so - // if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the - // other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot - // be odd. 
- constexpr bool AtomicsImplementationExists = - !(std::is_same_v || std::is_same_v || - std::is_same_v) || - (CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - if constexpr(AtomicsImplementationExists) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - if constexpr(AtomicsImplementationExists) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - 
{ - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; + using Invoker = typename DeviceGemmCommon::Invoker; static bool IsSupportedArgument(const Argument& arg) { - if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) - { - return false; - } return DeviceGemmCommon::IsSupportedArgument(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp new file mode 100644 index 0000000000..1b1a1fcc6c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp @@ -0,0 +1,360 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle + : public DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using AScaleLayout = tensor_layout::gemm::ColumnMajor; + using BScaleLayout = BLayout; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< + ALayout, + BLayout, + DsLayout, + CLayout, + Tuple, + AScaleDataType, + Tuple, + BScaleDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + 
BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB, + true, + AScaleLayout, + BScaleLayout>; + + using Argument = typename GridwiseGemm::Argument; + int GetPreShuffleParameters() override { return NPerWmma; } + + using DeviceGemmCommon = + DeviceGemm_Wmma_CShuffleV3_Common, + Tuple, + DsDataType, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + true>; // IsBPreshuffle + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + // with splitk the implementation doesn't work + // when KRead % ScaleBlockK != 0, independently of K padding + if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0) + { + return false; + } + + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation cde_element_op, + index_t KBatch) + { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? 
math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return Argument{std::array{p_a}, + std::array{p_b}, + p_ds, + static_cast(p_e), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) override + { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? 
math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return std::make_unique(std::array{p_a}, + std::array{p_b}, + p_ds, + static_cast(p_e), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else { const auto kernel = kernel_gemm_xdl_cshuffle_v3; - Run(kernel); + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } } else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + 
kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } } } } @@ -315,6 +350,20 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 { auto& arg = *dynamic_cast(base_arg); arg.KBatch = KBatch; + if(get_warp_size() == 64) + { + arg.KRead = GridwiseGemm64::CalculateKRead(arg.K, KBatch); + arg.KPadded = GridwiseGemm64::CalculateKPadded(arg.K, KBatch); + arg.AK0 = GridwiseGemm64::CalculateAK0Padded(arg.K, KBatch); + arg.BK0 = GridwiseGemm64::CalculateBK0Padded(arg.K, KBatch); + } + else + { + arg.KRead = GridwiseGemm32::CalculateKRead(arg.K, KBatch); + arg.KPadded = GridwiseGemm32::CalculateKPadded(arg.K, KBatch); + arg.AK0 = GridwiseGemm32::CalculateAK0Padded(arg.K, KBatch); + arg.BK0 = GridwiseGemm32::CalculateBK0Padded(arg.K, KBatch); + } } static constexpr bool IsValidCompilationParameter() @@ -325,6 +374,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 static bool IsSupportedArgument(const Argument& arg) { + // with splitk the implementation doesn't work + // when KRead % ScaleBlockK != 0, independently of K padding + if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0) + { + return false; + } + if(!ck::is_xdl_wmma_supported()) { return false; @@ -385,6 +441,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + return Argument{static_cast(p_a), static_cast(p_b), p_ds, @@ -396,6 +460,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 StrideB, StrideDs, StrideC, + StrideScaleA, + StrideScaleB, static_cast(p_a_scale), static_cast(p_b_scale), 1, @@ -425,6 +491,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override { + index_t StrideScaleA = ck::is_same_v + ? 
math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + return std::make_unique(static_cast(p_a), static_cast(p_b), p_ds, @@ -436,6 +510,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 StrideB, StrideDs, StrideC, + StrideScaleA, + StrideScaleB, static_cast(p_a_scale), static_cast(p_b_scale), 1, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp index e824fcc9dd..491f3a7dac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" @@ -86,12 +86,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale, // DsLayout CLayout, Tuple, + void, // AScaleType Tuple, BScaleDataType, AccDataType, @@ -103,6 +104,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale{StrideB}, std::array{}, // StrideDs_ StrideC, + 0, // StrideScaleA StrideScaleB, + nullptr, p_b_scale, KBatch, a_element_op, @@ -245,7 +249,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale{StrideB}, std::array{}, // StrideDs_ StrideC, + 0, // StrideScaleA StrideScaleB, + nullptr, // p_a_scale static_cast(p_b_scale), KBatch, 
a_element_op, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index 6706365fb7..e96ec58cba 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -38,7 +38,8 @@ template + typename ComputeTypeB, + bool IsBPreShuffled = false> struct DeviceGemm_Wmma_CShuffleV3_Common { @@ -189,61 +190,174 @@ struct DeviceGemm_Wmma_CShuffleV3_Common if(has_main_k_block_loop) { // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + if constexpr(IsBPreShuffled) { - if(arg.KBatch > 1) + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - if constexpr(AtomicsImplementationExists) + if(arg.KBatch > 1) { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); + if constexpr(AtomicsImplementationExists) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + 
TailNumber::Even>; + Run(kernel); + } } - } - else - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); } } else { - // TODO: Implement - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - if constexpr(AtomicsImplementationExists) + if(arg.KBatch > 1) + { + if constexpr(AtomicsImplementationExists) + { + const auto kernel = kernel_gemm_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + } + else { const auto kernel = kernel_gemm_wmma_cshuffle_v3; Run(kernel); } } - else + } + } + else + { + if constexpr(IsBPreShuffled) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); + if(arg.KBatch > 1) + { + if constexpr(AtomicsImplementationExists) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } 
+ } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + if constexpr(AtomicsImplementationExists) + { + const auto kernel = kernel_gemm_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } } } } @@ -299,6 +413,14 @@ struct DeviceGemm_Wmma_CShuffleV3_Common return false; } + if constexpr(IsBPreShuffled) + { + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + } + return GridwiseGemm::CheckValidity(arg); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 4526eb3186..69f8f44390 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -388,11 +388,11 @@ struct ABTransferThreadTiles // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 return transform_tensor_descriptor( BlockDesc{}, - make_tuple( - make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), + make_tuple(make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index f58f67dc6b..121ca258be 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -895,8 +895,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 c_thread_buf.Clear(); // Empty BScale struct for the blockwise pipeline. - using BScale = typename BlockwiseGemmPipe::Empty; - auto b_scale_struct = BScale{}; + using ABScale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = ABScale{}; + auto b_scale_struct = ABScale{}; /*******************************************************************************/ // @@ -919,6 +920,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 b0_block_buf, b0_block_slice_copy_step, acc0_thread_buf, + a_scale_struct, b_scale_struct, KBlockMainLoop, 1); // num_k_block_per_scale diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index e55ac807c5..fea0102337 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -618,8 +618,9 @@ struct GridwiseGemm_wmma_cshuffle_v3 __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); // BScale struct (Empty) - using BScale = typename BlockwiseGemmPipe::Empty; - auto b_scale_struct = BScale{}; + using Scale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = Scale{}; + auto b_scale_struct = Scale{}; const index_t num_k_block_per_scale = GetKBlockPerScale(); @@ -627,6 +628,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 decltype(bs_grid_desc_bk0_n_bk1), decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock), decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(a_scale_struct), decltype(b_scale_struct), decltype(epilogue_args), HasMainKBlockLoop, @@ -646,6 +648,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 block_m_id, block_n_id, num_k_block_per_scale, + a_scale_struct, 
b_scale_struct, epilogue_args, k_id); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp similarity index 58% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp index 8684731c96..ac5b7dd0c4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp @@ -23,6 +23,7 @@ template -struct GridwiseGemm_wmma_cshuffle_v3_b_scale + BlockGemmPipelineScheduler BlkGemmPipeSched, + BlockGemmPipelineVersion BlkGemmPipelineVer, + typename ComputeTypeA, + typename ComputeTypeB, + bool PermuteA, + bool PermuteB, + bool IsBPreShuffled = false, + typename AScaleLayout = ALayout, + typename BScaleLayout = BLayout> +struct GridwiseGemm_wmma_cshuffle_v3_ab_scale : GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, @@ -123,7 +128,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale ComputeTypeB, PermuteA, PermuteB, - false, + IsBPreShuffled, true> { using Base = GridwiseGemm_wmma_cshuffle_v3_base< @@ -177,7 +182,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale ComputeTypeB, PermuteA, PermuteB, - false, + IsBPreShuffled, true>; using Base::I0; @@ -233,6 +238,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array StrideBs_, std::array StrideDs_, index_t StrideE_, + index_t StrideScaleA_, index_t StrideScaleB_, index_t KBatch_) : M{M_}, @@ -242,6 +248,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale StrideBs{StrideBs_}, StrideDs{StrideDs_}, StrideE{StrideE_}, + StrideScaleA{StrideScaleA_}, StrideScaleB{StrideScaleB_}, KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, @@ -251,7 +258,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, 
MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)} + NBlock{CalculateNBlock(N_)}, + Kt{K_} { } @@ -275,11 +283,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale }); std::cout << " }, "; } - std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", " - << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead - << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 - << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" - << std::endl; + std::cout << "SE:" << StrideE << ", " << "SScaleA:" << StrideScaleA << ", " + << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded + << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; @@ -289,6 +297,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array StrideBs; std::array StrideDs; index_t StrideE; + index_t StrideScaleA; index_t StrideScaleB; index_t KBatch; index_t MPadded; @@ -299,6 +308,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale index_t BK0; index_t MBlock; index_t NBlock; + index_t Kt; }; // Argument @@ -315,7 +325,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array StrideBs_, std::array StrideDs_, index_t StrideE_, + index_t StrideScaleA_, index_t StrideScaleB_, + const AScaleType* p_a_scale_grid_, const BScaleType* p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, @@ -329,12 +341,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale StrideBs_, StrideDs_, StrideE_, + StrideScaleA_, StrideScaleB_, k_batch_}, p_as_grid{}, p_bs_grid{}, p_ds_grid{}, p_e_grid{p_e_grid_}, + p_a_scale_grid{p_a_scale_grid_}, p_b_scale_grid{p_b_scale_grid_}, a_element_op{a_element_op_}, b_element_op{b_element_op_}, @@ -379,6 +393,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale DsGridPointer p_ds_grid; EDataType* p_e_grid; + const 
AScaleType* p_a_scale_grid; const BScaleType* p_b_scale_grid; const AElementwiseOperation a_element_op; const BElementwiseOperation b_element_op; @@ -407,34 +422,52 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; }); } - if constexpr(is_same_v) + if constexpr(IsBPreShuffled) { - static_for<0, NumBTensor, 1>{}( - [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; }); + static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; }); } - else if constexpr(is_same_v) + else { - if constexpr(!PermuteB) + if constexpr(is_same_v) { - static_for<0, NumBTensor, 1>{}( - [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; + }); } - else + else if constexpr(is_same_v) { - const int k0_offset = karg.KRead * karg.N; - static_for<0, NumBTensor, 1>{}( - [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; }); + if constexpr(!PermuteB) + { + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; }); + } + else + { + const int k0_offset = karg.KRead * karg.N; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; }); + } } } - // Calculate B scale offset - if constexpr(is_same_v) + // Calculate A scale offset + if constexpr(is_same_v) { - scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB; + scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK); } - else if constexpr(is_same_v) + else if constexpr(is_same_v) { - scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK); + scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleB; + } + else if 
constexpr(is_same_v) + { + scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK); } if(k_id < karg.KBatch - 1) @@ -458,77 +491,225 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array a_k_split_offset; std::array b_k_split_offset; - index_t scale_k_split_offset; // New member for scale matrix offset + index_t scale_a_k_split_offset; // A scale matrix offset + index_t scale_b_k_split_offset; // B scale matrix offset index_t c_reduce_offset; }; using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; // return block_id to C matrix tile idx (m0, n0) mapping - // if arch = gfx942 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; - // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - template - __device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, - const BScaleType* p_b_scale_grid, - index_t block_n_id) + __device__ static constexpr auto + MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA) { - const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto BM = math::integer_divide_ceil(M, ScaleBlockM); + const auto BK = math::integer_divide_ceil(K, ScaleBlockK); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA)); + } + } - static constexpr auto wmma = - WmmaSelector{}; - static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma; + template + __device__ static auto + MakeAScale(const Problem& problem, const AScaleType* p_a_scale_grid, index_t block_m_id) + { + if constexpr(ck::is_same_v) + { + using AScale = typename BlockwiseGemmPipe::Empty; + return AScale{}; + } + else + { +#if defined(__gfx11__) + // TODO: remove this restriction + static_assert(ScaleBlockM >= MPerWmma, + 
"ScaleBlockM must be greater equal than MPerWmma"); +#endif + static_assert( + ScaleBlockK >= + WmmaSelector:: + selected_wmma.k_per_wmma, + "ScaleBlockK must be greater equal than KPerWmma"); - static constexpr auto ScaleSliceSizeN = NRepeat; - static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + const auto a_scale_grid_desc_am_ak = + MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA); - constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + constexpr auto wmma = + WmmaSelector{}; + constexpr auto RegSizePerWmmaFull = + wmma.selected_wmma.num_acc_vgprs_per_wave * wmma.selected_wmma.acc_pack_number; + constexpr auto RegSizePerWmma = + math::integer_divide_ceil(RegSizePerWmmaFull, ScaleBlockM); - auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma + - (get_thread_local_1d_id() / 32) % NWaves * NPerWmma; - auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread; + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - auto b_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0, 1>, - 1, - ScaleSliceSizeK, - 1, - false>( - b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, - b_thread_offset_k / ScaleBlockK)); + constexpr auto ScaleSliceSizeM = + ScaleBlockM < MPerWmma ? 
MRepeat * RegSizePerWmma + : math::integer_divide_ceil(MPerBlock, ScaleBlockM); + constexpr auto ScaleSliceStrideM = + math::integer_divide_ceil(MWaves * MPerWmma, ScaleBlockM); + constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); - auto b_scale_thread_buf = make_static_buffer( - b_scale_thread_desc.GetElementSpaceSize()); + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); - using BScale = - typename BlockwiseGemmPipe::template BScale; + auto a_thread_offset_m = + ((get_thread_local_1d_id() % 32) / MPerWmma * RegSizePerWmma) / + math::integer_divide_ceil(ScaleBlockM, RegSizePerWmmaFull) + + (get_thread_local_1d_id() / 32) / NWaves * MPerWmma / ScaleBlockM; - return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf}; + constexpr index_t VectorDim = + is_same::value ? 0 : 1; + constexpr index_t VectorSize = + is_same::value ? RegSizePerWmma + : ScaleSliceSizeK; + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + VectorDim, + VectorSize, + 1, + true>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset_m, 0)); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + using AScale = + typename BlockwiseGemmPipe::template ABScale; + + return AScale{a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf}; + } + } + + __device__ static constexpr auto + MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB) + { + const auto BN = math::integer_divide_ceil(N, ScaleBlockN); + const auto BK = math::integer_divide_ceil(K, ScaleBlockK); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB)); + } + } + + template + 
__device__ static auto + MakeBScale(const Problem& problem, const BScaleType* p_b_scale_grid, index_t block_n_id) + { + if constexpr(ck::is_same_v) + { + using BScale = typename BlockwiseGemmPipe::Empty; + return BScale{}; + } + else + { + static_assert( + ScaleBlockK >= + WmmaSelector:: + selected_wmma.k_per_wmma, + "ScaleBlockK must be greater equal than KPerWmma"); + + const auto b_scale_grid_desc_bn_ak = + MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB); + + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto ScaleSliceSizeN = + ScaleBlockN < NPerWmma ? NRepeat + : math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr auto ScaleSliceStrideN = + math::integer_divide_ceil(NWaves * NPerWmma, ScaleBlockN); + constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto b_thread_offset_n = (get_thread_local_1d_id() % NPerWmma + + (get_thread_local_1d_id() / 32) % NWaves * NPerWmma) / + ScaleBlockN; + + constexpr index_t VectorDim = + is_same::value ? 0 : 1; + constexpr index_t VectorSize = + is_same::value ? 
1 : ScaleSliceSizeK; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + VectorDim, + VectorSize, + 1, + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, 0)); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + using BScale = + typename BlockwiseGemmPipe::template ABScale; + + return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf}; + } } __device__ static index_t GetKBlockPerScale() { - return (ScaleBlockK + KPerBlock - 1) / KPerBlock; + if constexpr(ck::is_same_v && ck::is_same_v) + { + return 0; + } + else + { + return (ScaleBlockK + KPerBlock - 1) / KPerBlock; + } } template ( @@ -562,12 +746,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, problem.MBlock, problem.NBlock); - // B Scale grid - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), - math::integer_divide_ceil(problem.K, ScaleBlockK)), - make_tuple(problem.StrideScaleB, 1)); - // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -585,8 +763,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + // AScale struct + auto a_scale_struct = MakeAScale<1>(problem, p_a_scale_grid, block_m_id); + // BScale struct - auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id); + auto b_scale_struct = MakeBScale<1>(problem, p_b_scale_grid, block_n_id); const index_t num_k_block_per_scale = GetKBlockPerScale(); @@ -594,6 +775,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale decltype(bs_grid_desc_bk0_n_bk1), decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock), 
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(a_scale_struct), decltype(b_scale_struct), decltype(epilogue_args), HasMainKBlockLoop, @@ -613,8 +795,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale block_m_id, block_n_id, num_k_block_per_scale, + a_scale_struct, b_scale_struct, - epilogue_args); + epilogue_args, + k_id); } // NOTE: Wrapper function to have __global__ function in common @@ -626,7 +810,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __device__ static void Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg, - EpilogueArgument& epilogue_args) + EpilogueArgument& epilogue_args, + const index_t k_id = 0) { // shift A matrices pointer for splitk AsGridPointer p_as_grid_splitk; @@ -644,18 +829,40 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale splitk_batch_offset.b_k_split_offset[i]; }); + const AScaleType* p_a_scale_grid_ptr; + if constexpr(ck::is_same_v) + { + p_a_scale_grid_ptr = karg.p_a_scale_grid; + } + else + { + p_a_scale_grid_ptr = karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset; + } + + const BScaleType* p_b_scale_grid_ptr; + if constexpr(ck::is_same_v) + { + p_b_scale_grid_ptr = karg.p_b_scale_grid; + } + else + { + p_b_scale_grid_ptr = karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset; + } + Run( p_as_grid_splitk, p_bs_grid_splitk, karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_a_scale_grid_ptr, + p_b_scale_grid_ptr, p_shared, karg, karg.a_element_op, karg.b_element_op, karg.cde_element_op, - epilogue_args); + epilogue_args, + k_id); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 04d1d98448..81aa1ac986 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -69,6 +69,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif } +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack); + const index_t k_id = blockIdx.z * num_k_per_block; + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args, k_id); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + template ( - karg.p_a_grid, - karg.p_b_grid, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, + karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset, p_shared, karg, karg.a_element_op, @@ -405,31 +407,33 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 } } - __host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K) + __host__ __device__ static constexpr auto + MakeAScaleGridDesciptor_M_K(index_t M, 
index_t K, index_t StrideScaleA) { const auto BM = math::integer_divide_ceil(M, ScaleBlockM); const auto BK = math::integer_divide_ceil(K, ScaleBlockK); if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1)); + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1)); } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM)); + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA)); } } - __host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K) + __host__ __device__ static constexpr auto + MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB) { const auto BN = math::integer_divide_ceil(N, ScaleBlockN); const auto BK = math::integer_divide_ceil(K, ScaleBlockK); if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1)); + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1)); } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN)); + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB)); } } @@ -548,6 +552,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t StrideB_, std::array StrideDs_, index_t StrideC_, + index_t StrideScaleA_, + index_t StrideScaleB_, index_t KBatch_) : M{M_}, N{N_}, @@ -556,6 +562,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 StrideB{StrideB_}, StrideDs{StrideDs_}, StrideC{StrideC_}, + StrideScaleA{StrideScaleA_}, + StrideScaleB{StrideScaleB_}, KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, @@ -585,7 +593,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t StrideB; std::array StrideDs; index_t StrideC; - + index_t StrideScaleA; + index_t StrideScaleB; index_t KBatch; index_t MPadded; index_t 
NPadded; @@ -611,13 +620,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t StrideB_, std::array StrideDs_, index_t StrideC_, + index_t StrideScaleA_, + index_t StrideScaleB_, const AScaleType* p_a_scale_grid_, const BScaleType* p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + : Problem{M_, + N_, + K_, + StrideA_, + StrideB_, + StrideDs_, + StrideC_, + StrideScaleA_, + StrideScaleB_, + k_batch_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, p_ds_grid{}, @@ -673,6 +693,28 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 b_k_split_offset = blockIdx.z * karg.KRead; } + // Calculate A scale offset + if constexpr(is_same_v) + { + scale_a_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + } + else if constexpr(is_same_v) + { + scale_a_k_split_offset = + blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_b_k_split_offset = + blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleB; + } + else if constexpr(is_same_v) + { + scale_b_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + } + if(blockIdx.z < static_cast(karg.KBatch - 1)) { karg.K = karg.KRead; @@ -685,6 +727,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t a_k_split_offset; index_t b_k_split_offset; + index_t scale_a_k_split_offset; // A scale matrix offset + index_t scale_b_k_split_offset; // B scale matrix offset }; __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -1221,8 +1265,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K); - const auto 
b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K); + const auto a_scale_grid_desc_am_ak = + MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA); + const auto b_scale_grid_desc_bn_ak = + MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp index faf10c2cce..d4ddbafeee 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp @@ -16,7 +16,231 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_WMMA_FP8 +// Row, Col +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& 
instances); + +// Row, Row +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +// Col, Row +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); +#endif +#ifdef CK_USE_XDL // Row, Col void 
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( std::vector>>& instances); #endif +#endif template -struct DeviceOperationInstanceFactory, - CLayout, - A0DataType, - A1DataType, - B0DataType, - B1DataType, - Tuple<>, - CDataType, - 1, - 128, - 128, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough>> +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD_ABScaleSplitK, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>> +{ + using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif +#ifdef CK_USE_WMMA_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + op_ptrs); + + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances( + op_ptrs); + + 
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances( + op_ptrs); + + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances( + op_ptrs); + } + } +#endif +#endif + return op_ptrs; + } +}; + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD_ABScale, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>> { using DeviceOp = DeviceGemmMultipleD_ABScale; + PassThrough, + PassThrough, + PassThrough>; static auto GetInstances() { std::vector> op_ptrs; #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_XDL if constexpr(is_same_v && is_same_v && is_same_v) { @@ -328,6 +655,33 @@ struct DeviceOperationInstanceFactory, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA_FP8 #endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp index a8d9545194..d660c18fd0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp 
+++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp @@ -17,6 +17,47 @@ namespace tensor_operation { namespace device { namespace instance { #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#if(defined(CK_USE_WMMA) && defined(CK_USE_WMMA_FP8)) +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances); + +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances); +#endif // CK_USE_WMMA && CK_USE_WMMA_FP8 + +#ifdef CK_USE_XDL void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( std::vector>>& instances); #endif +#endif + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK< + ALayout, + BLayout, + Tuple<>, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>> +{ + using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK< + ALayout, + BLayout, + Tuple<>, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK at the moment +#endif // CK_USE_XDL + +#if(defined(CK_USE_WMMA) && 
defined(CK_USE_WMMA_FP8)) +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + op_ptrs); + + add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances( + op_ptrs); + } + } +#endif +#endif // CK_USE_WMMA && CK_USE_WMMA_FP8 + + return op_ptrs; + } +}; template > op_ptrs; +#ifdef CK_USE_XDL #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) if constexpr(is_same_v && is_same_v && is_same_v) @@ -162,6 +280,35 @@ struct DeviceOperationInstanceFactory< } } #endif +#endif + +#if defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK instances + using Wrapper = DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper< + ALayout, + BLayout, + Tuple<>, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA + return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index ef037526ca..575e14d5bb 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -103,6 +103,16 @@ function(add_instance_library INSTANCE_NAME) message(DEBUG "removing gemm_universal_preshuffle_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() + # Do not build gemm_ab_scale_f8 for any targets except gfx94, gfx95 and gfx12 + if(NOT (INST_TARGETS MATCHES "gfx942" OR 
INST_TARGETS MATCHES "gfx950" OR INST_TARGETS MATCHES "gfx12") AND (source_name MATCHES "gemm_ab_scale") AND (source_name MATCHES "_f8_f8_")) + message(DEBUG "removing gemm_ab_scale_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + # Do not build gemm_blockscale_wp_f8 for any targets except gfx94, gfx95 and gfx12 + if(NOT (INST_TARGETS MATCHES "gfx942" OR INST_TARGETS MATCHES "gfx950" OR INST_TARGETS MATCHES "gfx12") AND (source_name MATCHES "gemm_blockscale_wp") AND (source_name MATCHES "_f8_f8_")) + message(DEBUG "removing gemm_blockscale_wp_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() # Only build tf32 instances for gfx942 & gfx950 if(source_name MATCHES "_tf32_") if(NOT ((INST_TARGETS MATCHES "gfx942|gfx950") AND CK_ENABLE_TF32)) @@ -300,7 +310,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found gemm_multiply_multiply instances, but gfx94/gfx95/gfx11/gfx12 not on the target list. Skipping. ${cmake_instance}") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle|gemm_blockscale" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) + if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle|gemm_blockscale|gemm_ab_scale" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) message(DEBUG "Found gemm_f8 instances, but gfx94/gfx95 not on the target list. Skipping.") set(add_inst 0) endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt index a315db8bdd..0512b01175 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt @@ -1,21 +1,38 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_AB_SCALE_INSTANCES) list(APPEND GEMM_AB_SCALE_INSTANCES # Row, Col + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + # Row, Row + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp + device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp 
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp + # Col, Row + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp + device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp @@ -27,11 +44,13 @@ set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_s set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + # Row, Row 
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + # Col, Row set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp new file mode 100644 index 0000000000..a4058ca1c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp @@ -0,0 +1,79 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + 
DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, 
S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S< 8, 
32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 
128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp 
new file mode 100644 index 0000000000..ad0667dd10 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..dbdfd41e32 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..1380df5291 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..90dbb9c9d5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp new file mode 100644 index 0000000000..c45adb91c6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp @@ -0,0 +1,80 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + 
DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 
32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Memory friendly + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 
256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp new file mode 100644 index 0000000000..766279520a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..b837c35810 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..2fc87ba6ad --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..2188a64c98 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp new file mode 100644 index 0000000000..cc1be58946 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp @@ -0,0 +1,95 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 64, 16, 16, 16, 16, 4, 2, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 64, 16, 16, 16, 16, 2, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 64, 16, 16, 16, 16, 4, 2, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 64, 16, 16, 16, 16, 2, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + 
//#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Memory friendly + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, 
BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 256, 8, 16, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 1, 4, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, 
F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 2, 4, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, 
F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 16, 16, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp new file 
mode 100644 index 0000000000..3c140ef980 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..d68b755506 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..5822fd0b2a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..f4661891d1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt index b37a22d895..dd7596447e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") set(GEMM_BLOCKSCALE_WP_INSTANCES) @@ -10,6 +10,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + + device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp + device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp ) check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP) check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION) diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp new file mode 100644 index 0000000000..023d1ac2b8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp @@ -0,0 +1,77 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //######################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //######################################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | Wmma| Wmma| | | 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //######################################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //######################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S <1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S <1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S <1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +template +using device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //######################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //######################################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //######################################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //######################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 
1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + 
DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, 
Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp new file mode 100644 index 0000000000..59fe63421a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp new file mode 100644 index 0000000000..2b5670ead3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp index 5396a52e21..f3055575ea 100644 --- a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp @@ -109,8 +109,8 @@ bool profile_gemm_ab_scale_impl(int do_verification, case 1: a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); break; default: a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); @@ -302,7 +302,7 @@ bool profile_gemm_ab_scale_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; + << gb_per_sec << " GB/s, " << op_name << ", KBatch " << KBatch << std::endl; if(tflops > best_tflops) { diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 49fef5a0fc..8642cc59e6 
100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -29,7 +29,7 @@ void preShuffleBuffer(const InOutDataType* src, InOutDataType* dst, int N, int K { int KPack = 16; int NLane = NXdl; - int KLane = 64 / NLane; + int KLane = ck::get_warp_size() / NLane; int K0 = K / (KLane * KPack); // K -> K0 KLane KPack diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b7db14945d..802f29024c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -261,6 +261,7 @@ add_subdirectory(gemm_multiply_multiply_wp) add_subdirectory(gemm_split_k) add_subdirectory(gemm_universal) add_subdirectory(gemm_universal_preshuffle) +add_subdirectory(gemm_ab_scale) add_subdirectory(gemm_b_scale) add_subdirectory(gemm_universal_streamk) add_subdirectory(gemm_reduce) diff --git a/test/gemm_ab_scale/CMakeLists.txt b/test/gemm_ab_scale/CMakeLists.txt new file mode 100644 index 0000000000..21203aafaa --- /dev/null +++ b/test/gemm_ab_scale/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_ab_scale test_gemm_ab_scale.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_ab_scale PRIVATE utility device_gemm_ab_scale_instance) + endif() +endif() diff --git a/test/gemm_ab_scale/test_gemm_ab_scale.cpp b/test/gemm_ab_scale/test_gemm_ab_scale.cpp new file mode 100644 index 0000000000..01c3e2ffdb --- /dev/null +++ b/test/gemm_ab_scale/test_gemm_ab_scale.cpp @@ -0,0 +1,236 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_ab_scale_util.hpp" + +using BF16 = ck::bhalf_t; +using F32 = float; +using F8 = ck::f8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmABScale_MK_NK : public ck::test::TestGemmABScale< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmABScale_MK_KN : public ck::test::TestGemmABScale< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmABScale_KM_KN : public ck::test::TestGemmABScale< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< + // ADataType, BDataType, ComputeDataType, EDataType + std::tuple< F8, F32, F8, F32, F8, BF16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmABScale_MK_NK, KernelTypes); +TYPED_TEST_SUITE(TestGemmABScale_MK_KN, KernelTypes); +TYPED_TEST_SUITE(TestGemmABScale_KM_KN, KernelTypes); + +// Row Col +TYPED_TEST(TestGemmABScale_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_NK, SmallMPadK) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 704; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + 
constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideE = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideE); +} + +// Row Row +TYPED_TEST(TestGemmABScale_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_KN, SmallMPadK) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 704; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideE = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideE); +} + +// Col Row +TYPED_TEST(TestGemmABScale_KM_KN, SmallM) +{ + std::vector Ms{16, 32}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmABScale_KM_KN, SmallMPadK) +{ + std::vector Ms{16, 32}; + constexpr int 
N = 512; + constexpr int K = 704; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmABScale_KM_KN, MidLargeM) +{ + std::vector Ms{128, 256}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmABScale_KM_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = N; + constexpr int StrideE = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideE); + } +} diff --git a/test/gemm_ab_scale/test_gemm_ab_scale_util.hpp b/test/gemm_ab_scale/test_gemm_ab_scale_util.hpp new file mode 100644 index 0000000000..b54e5ce2e5 --- /dev/null +++ b/test/gemm_ab_scale/test_gemm_ab_scale_util.hpp @@ -0,0 +1,102 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/data_type.hpp" +#include "profiler/profile_gemm_ab_scale_impl.hpp" + +namespace ck { +namespace test { + +template +class TestGemmABScale : public testing::Test +{ + using F32 = float; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using ELayout = std::tuple_element_t<2, Tuple>; + using A0DataType = std::tuple_element_t<3, Tuple>; + using A1DataType = std::tuple_element_t<4, Tuple>; + using B0DataType = std::tuple_element_t<5, Tuple>; + using B1DataType = std::tuple_element_t<6, Tuple>; + using ComputeDataType = std::tuple_element_t<7, Tuple>; + using EDataType = std::tuple_element_t<8, Tuple>; + + public: + static constexpr ck::index_t ScaleBlockM = 1; + static constexpr ck::index_t ScaleBlockN = 128; + static constexpr ck::index_t ScaleBlockK = 128; + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideE) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideE, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideE, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::profiler::profile_gemm_ab_scale_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideE, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // 
namespace ck diff --git a/test/gemm_blockscale_wp/CMakeLists.txt b/test/gemm_blockscale_wp/CMakeLists.txt index a095968035..a0750255d1 100644 --- a/test/gemm_blockscale_wp/CMakeLists.txt +++ b/test/gemm_blockscale_wp/CMakeLists.txt @@ -2,8 +2,8 @@ # SPDX-License-Identifier: MIT if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") - add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp) + add_gtest_executable(test_gemm_blockscale_wp_fp8 test_gemm_blockscale_wp_fp8.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility device_gemm_blockscale_wp_instance) + target_link_libraries(test_gemm_blockscale_wp_fp8 PRIVATE utility device_gemm_blockscale_wp_instance) endif() endif() diff --git a/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_fp8.cpp similarity index 100% rename from test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp rename to test/gemm_blockscale_wp/test_gemm_blockscale_wp_fp8.cpp From 715671e419cbbebe72109ceeeed9d582cca34d02 Mon Sep 17 00:00:00 2001 From: eliotwang <46883838+eliotwang@users.noreply.github.com> Date: Thu, 11 Dec 2025 23:20:29 +0800 Subject: [PATCH 41/65] Bf16*fp4 gemm (#2801) * support bf16*mxfp4 gemm * rebase bf16*fp4 example to develop branch * Clean up commented debug code in GEMM kernel * rename example folder * support bf16*mxfp4 gemm * rebase bf16*fp4 example to develop branch * Clean up commented debug code in GEMM kernel * rename example folder * rebase to new develop * fix clang format * update code according to reviewer's comment * Update README.md * update code according to reviewer's comment * update code according to reviewer's comment * Update CMakeLists.txt * Update README.md * Update CMakeLists.txt * Delete files * Delete files * Add unit tests * Update test_gemm_quant_base.hpp * merge bf16*fp4 example to develop branch * fix clang format * fix clang format * Update CMakeLists.txt * fix ci test 
* fix clang format * resolve conflicts --------- Co-authored-by: eliotwang Co-authored-by: ShaoChunLee Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng Co-authored-by: Thomas Ning --- .../38_block_scale_gemm/CMakeLists.txt | 1 + example/ck_tile/38_block_scale_gemm/README.md | 4 +- .../gemm_bquant_quantgrouped_bf16mxfp4.cpp | 41 ++ .../38_block_scale_gemm/gemm_quant.cpp | 5 +- .../38_block_scale_gemm/gemm_utils.hpp | 6 +- .../run_gemm_quant_example.inc | 121 ++-- .../core/arch/amd_buffer_addressing.hpp | 7 +- include/ck_tile/host/check_err.hpp | 32 +- .../ck_tile/host/reference/reference_gemm.hpp | 57 ++ include/ck_tile/ops/common/utils.hpp | 1 + .../ops/epilogue/cshuffle_epilogue.hpp | 14 +- .../block/block_universal_gemm_as_bs_cr.hpp | 6 +- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 21 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 44 +- include/ck_tile/ops/gemm_quant.hpp | 3 + .../gemm_quant/kernel/gemm_quant_kernel.hpp | 49 +- .../gemm_mxfp4_pipeline_ag_bg_cr_base.hpp | 59 ++ .../gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp | 140 ++++ .../gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp | 665 ++++++++++++++++++ test/ck_tile/gemm_block_scale/CMakeLists.txt | 0 .../gemm_block_scale/test_gemm_quant_base.hpp | 6 +- .../test_gemm_quant_bquant.cpp | 6 + .../test_gemm_quant_fixtures.hpp | 109 ++- 23 files changed, 1260 insertions(+), 137 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp mode change 100644 => 100755 test/ck_tile/gemm_block_scale/CMakeLists.txt diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 
d6b63dc47b..40f06ec97a 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -16,6 +16,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp gemm_bquant_quantgrouped_fp8i4.cpp + gemm_bquant_quantgrouped_bf16mxfp4.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb.cpp diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 3a30c2bad3..eb36ae5800 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -23,7 +23,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming - **Preshuffled GEMM**: Shuffle the GEMM of B (weight) matrix in the warp layout and bypass the shared memory to do the GEMM calculation. Best performance solution for GEMM. - **TransposeC**: Transpose the C Matrix Output layout to have the best coalesced scale reading - **Preshuffled Quant**: Preshuffle the input matrix to load multiple Quant warp blocks along the selected dimension. -- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix). +- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix), uint8 (split into two fp4 in the pipeline (for B Matrix)). - **Validation**: CPU/GPU validation and error tolerance options. ## build @@ -53,7 +53,7 @@ args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) - -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8) + -prec Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16fp4 (default for both AQuant and Bquant: fp8) -warmup Number of iterations before benchmarking the kernel (default:50) -repeat Number of iterations to benchmark the kernel (default:1000) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp new file mode 100644 index 0000000000..a022ce18e1 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf16fp4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + + lut[hash_multiple_strings( + {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp 
b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 45d2151d5e..669bce2995 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4") + "bf8i4 or bf16fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") @@ -97,6 +97,8 @@ void bquant_quantgrouped_fp8i4_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_bf8i4_instance_factory( std::unordered_map>& lut); +void bquant_quantgrouped_bf16fp4_instance_factory( + std::unordered_map>& lut); void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_preshufflequant_instance_factory( @@ -128,6 +130,7 @@ int main(int argc, char* argv[]) bquant_quantgrouped_bf8_instance_factory(lut); bquant_quantgrouped_fp8i4_instance_factory(lut); bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_bf16fp4_instance_factory(lut); bquant_quantgrouped_preshuffleb_instance_factory(lut); bquant_quantgrouped_preshufflequant_instance_factory(lut); bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 2b2333b04c..aabbfff3bd 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -69,8 +69,10 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) { - using ComputeType = - std::conditional_t; + using ComputeType = std::conditional_t< + std::is_same_v, + ADataType, + std::conditional_t>; // Calculate 
thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 8a0dd9bc08..fa5e1f12e3 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -136,9 +136,13 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t, ck_tile::AQuantGemmPipelineAgBgCrMem>, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; + std::conditional_t< + GemmConfig::PreshuffleB == true, + ck_tile::WPQuantBPipelineAgBgCrV2, + std::conditional_t< + std::is_same_v, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>>; constexpr bool TiledPermuteN = (QuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; @@ -147,28 +151,31 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, QuantGroupSize::kN); } - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue, + typename TypeConfig::ADataType, + typename TypeConfig::BDataType>, + ck_tile::tuple<>, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + 
GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -205,7 +212,11 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + std::is_same_v ? args.K / 2 + : args.K, + args.N, + args.stride_B, + is_row_major(BLayout{}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes(); auto size_b_buffer = b_n.get_element_space_size_in_bytes(); @@ -427,7 +438,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, int rotating_count = arg_parser.get_int("rotating_count"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_B = ck_tile::get_default_stride( + (std::is_same_v) ? (K / 2) : K, + N, + stride_B, + is_row_major(b_layout)); stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); // Conditional stride calculation based on QuantMode @@ -454,8 +469,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + (std::is_same_v) ? 
(K / 2) : K, + N, + stride_B, + is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); @@ -499,13 +517,22 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); } else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) @@ -721,13 +748,23 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - ck_tile::reference_gemm_quant(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); + if constexpr(std::is_same_v) + ck_tile::reference_mxfp4gemm_quant( + a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); + else + ck_tile::reference_gemm_quant(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) { @@ -787,16 +824,18 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) using Col = ck_tile::tensor_layout::gemm::ColumnMajor; if((QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) && + QuantMode == ck_tile::QuantType::RowColQuant || + std::is_same_v) && GemmConfig::PreshuffleB) { throw std::runtime_error( - "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); + "Preshuffling 
weight matrix is not supported for AQuant, RowColQuant or bf16_fp4_gemm"); } if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) + std::is_same_v || + std::is_same_v) { std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 8830adfdd9..9c2ce62856 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1550,9 +1550,10 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! 
not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index ac388992d1..a1be8027b2 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -52,9 +52,19 @@ template ::value, - "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); + static_assert(is_any_of::value, + "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; if constexpr(is_any_of::value) @@ -113,9 +123,19 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) { - static_assert( - is_any_of::value, - "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); + static_assert(is_any_of::value, + "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 883b08fcaa..0aa296b8d9 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -246,6 +246,63 @@ CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor& a_m_k make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } +template +CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + AccDataType pasual = 0; + for(std::size_t k = 0; k < (K / 2); k++) + { + using ComputeType = float; + 
auto b_scale = type_convert(q((2 * k) / QuantGroupSize::kK, n)) - 127; + ComputeType v_a_0, v_a_1; + ComputeType v_b_0, v_b_1; + + v_a_0 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k)))); + v_a_1 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k + 1)))); + + if constexpr(std::is_same_v) + { + auto b_pack = type_convert(b_element_op(b_k_n(k, n))); + auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); + + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + + v_b_0 = type_convert(b_f4_lo) * b_scale_fp4; + v_b_1 = type_convert(b_f4_hi) * b_scale_fp4; + } + + pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1; + v_acc += pasual; + } + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + }; + + make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); + std::cout << std::endl; +} + template struct DataTypeTraits { static constexpr const char * name = template <> struct DataTypeTraits { static constexpr const char * name = "int8"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; template struct memOpToStr; template <> struct memOpToStr { static constexpr const char * name = "set"; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 9a7876f6a5..ad1862306a 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -92,11 +92,17 @@ struct CShuffleEpilogue using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; - using ATypeToUse = - std::conditional_t, BDataType, ADataType>; + using ATypeToUse = std::conditional_t || + std::is_same_v, + BDataType, + ADataType>; // Used for 
weight-only quantization kernel, B would be dequantized to the same data type as A - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = std::conditional_t || + std::is_same_v || + std::is_same_v, + ADataType, + BDataType>; + using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 8541ffa3a9..f6e26ad206 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -96,8 +96,10 @@ struct BlockUniversalGemmAsBsCr using ATypeToUse = std::conditional_t, BDataType, ADataType>; - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = std::conditional_t || + std::is_same_v, + ADataType, + BDataType>; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index f39d41a653..343e37ed66 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -17,10 +17,12 @@ struct GemmPipelineAgBgCrImplBase using BsLayout = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ADataType = remove_cvref_t{}, AsDataType>>; - using ALayout = remove_cvref_t{}, AsLayout>>; - using BDataType = remove_cvref_t{}, BsDataType>>; - using BLayout = remove_cvref_t{}, BsLayout>>; + using ADataType = remove_cvref_t{}, AsDataType>>; + using ALayout = remove_cvref_t{}, AsLayout>>; + using BInDataType = remove_cvref_t{}, BsDataType>>; + using BDataType = + std::conditional_t, ADataType, BInDataType>; + using BLayout = remove_cvref_t{}, BsLayout>>; static constexpr index_t MPerBlock = 
BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -270,12 +272,17 @@ struct GemmPipelineAgBgCrImplBase }(); auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); + using BLdsDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename BLdsLoadTileDistr::DstrEncode, - typename Problem::BDataType>::TransposedDstrEncode{}); + typename InputTileDistributionTraits::TransposedDstrEncode{}); + else return BLdsLoadTileDistr{}; }(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 76341af70b..a45d41189b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -303,8 +303,11 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + using BDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -585,9 +588,12 @@ struct UniversalGemmBasePolicy using BsDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using BLayout = remove_cvref_t{}, BsLayout>>; + using BInDataType = remove_cvref_t{}, BsDataType>>; - using BLayout = remove_cvref_t{}, BsLayout>>; - using BDataType = remove_cvref_t{}, BsDataType>>; + using BDataType = std::conditional_t, + typename Problem::ADataType, + typename 
Problem::BDataType>; if constexpr(Problem::FixedVectorSize) { @@ -729,13 +735,17 @@ struct UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using BDataType = remove_cvref_t; + constexpr index_t KPerBlock = std::is_same_v + ? Problem::BlockGemmShape::kK / 2 + : Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + std::is_same_v + ? 4 + : (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB()); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - - using BLayout = remove_cvref_t< - std::tuple_element_t{}, remove_cvref_t>>; + using BLayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { @@ -841,10 +851,12 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { - constexpr index_t smem_size_b = - integer_least_multiple(sizeof(typename Problem::BDataType) * - Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, - 16); + using BDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + constexpr index_t smem_size_b = integer_least_multiple( + sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16); return smem_size_b; } @@ -882,8 +894,10 @@ struct UniversalGemmPipelineAgBgCrPolicy using BDataType = remove_cvref_t; using ATypeToUse = std::conditional_t, BDataType, ADataType>; - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = std::conditional_t || + std::is_same_v, + ADataType, + BDataType>; using WarpGemm = WarpGemmDispatcher( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + if constexpr(std::is_same_v) + return 
make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + else + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); } } } @@ -885,10 +893,16 @@ struct QuantGemmKernel const auto& b_tensor_view = views.at(I2); if constexpr(std::is_same_v) { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + if constexpr(std::is_same_v) + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + else + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { @@ -1020,10 +1034,17 @@ struct QuantGemmKernel { if constexpr(std::is_same_v) { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); + if constexpr(std::is_same_v) + return make_tile_window( + b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + else + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); } else { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp new file mode 100644 index 0000000000..58019d703e --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +template +struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +{ + using Base = GemmPipelineAgBgCrImplBase; + using ADataType = typename Base::ADataType; + using ALayout = typename Base::ALayout; + using BDataType = typename Base::BDataType; + using BLayout = typename Base::BLayout; + using BlockGemmShape = typename Base::BlockGemmShape; + using QuantGroupSize = remove_cvref_t; + + using BQLayout = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN; + static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; + + static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize"); + static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize"); + + static_assert(NPerBlock % QuantGroupSize::kN == 0, + "NPerBlock must be a multiple of QuantGroupSize::kN"); + static_assert(KPerBlock % QuantGroupSize::kK == 0, + "KPerBlock must be a multiple of QuantGroupSize::kK"); + + // Create DRAM tile window for BQ + template + CK_TILE_DEVICE constexpr auto + GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const + { + static_assert(std::is_same_v); + + using YPerTile = number; + using XPerTile = number; + + auto bq_copy_dram_window = + make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile(), XPerTile()), + bq_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBQDramTileDistribution()); + return bq_copy_dram_window; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp 
b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..6ce2ff10fa --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "gemm_group_quant_utils.hpp" + +namespace ck_tile { + +struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy +{ + using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base::I0; + using Base::I1; + using Base::I2; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() + { + using BQLayout = remove_cvref_t; + using BQDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; + + static_assert(std::is_same_v); + return GetABQGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution() + { + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? 
Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + // Tile: KPerBlock X NPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + // Tile: NPerBlock X KPerBlock + else + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() + { + // using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2 + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t num_warps = BlockSize / get_warp_size(); + constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); + constexpr index_t b_vec = VecLoadSize > LargestVec ? 
LargestVec : VecLoadSize; + constexpr index_t K0 = KPerBlock / b_vec; + constexpr index_t K1 = K0 / KScale; + constexpr index_t K3 = K0 / K1; + constexpr index_t K2 = 1; + + constexpr index_t N0 = num_warps / NumWaveGroups; + constexpr index_t N1 = warp_size / K0; + constexpr index_t N2 = NPerBlock / (N0 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 0>>, + tuple, sequence<1, 0, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of QuantGroupSize!"); + + using WarpGemm = WarpGemmDispatcher; + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v); + + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy< + typename Problem::ADataType, + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>, + typename Problem::CDataType, + BlockWarps, + WarpGemm>; + + return BlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp new file mode 100644 index 0000000000..c113521d6b --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -0,0 +1,665 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register + +template +struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +{ + using Base = BaseGemmPipelineAgBgCrCompV3; + using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BDqDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using QuantGroupSize = remove_cvref_t; + + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); + + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t BQPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using ALayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN; + static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK; + + static constexpr index_t GetVectorSizeA() { return Policy::template 
GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetVectorSizeBQ() + { + return Policy::template GetVectorSizeBQ(); + } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using Base::PrefetchStages; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "mxfp4gemm_pipeline_AgBgCrCompV3", + concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', kPadM, kPadN, kPadK), + concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName()); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t 
B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + constexpr index_t BQ_Buffer_Load_Inst_Num = + NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", " + << "BQ vector size: " << GetVectorSizeBQ() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << ", " << "BQ buffer load inst: " << BQ_Buffer_Load_Inst_Num << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "QuantGroupSize: " << QuantGroupSize::GetName() << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public 
PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 + ? 
B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / + // sizeof(BDataType) + // ? 
sizeof(ComputeDataType) / + // sizeof(ADataType) : sizeof(ComputeDataType) + // / sizeof(BDataType); + constexpr auto num_mfma_stage1 = + num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) 
* ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/BQ Dram block window should have the same data type as appropriate " + "([A|B|BQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_bq_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + "Bq block window has incorrect lengths for defined BqLayout!"); + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert( + is_b_row_major + ? 
(KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + // ------------------------------------------------------------------------------------ + // Definitions of all needed tiles + // int b_block_stride = 0; + // A/B tiles in LDS + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load, (kN, kK/2) + // B LDS tile window for store, (kN, kK) + // B LDS tile for block GEMM + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // B scale DRAM tile window for load + // auto b_scale_copy_dram_window = + // make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + // bq_dram_block_window_tmp.get_window_lengths(), + // bq_dram_block_window_tmp.get_window_origin(), + // Policy::template GetBQDramLoadWindow()); + auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp); + + auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){}; + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + using ABlockTileDistr = 
decltype(a_copy_dram_window.get_tile_distribution()); + // using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_fp4_block_tile; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2); + + constexpr index_t b_scale_dram_tile_window_step = KPerBlock / QuantGroupSize::kK; + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // prefetch + // global read 0 + // auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){}; + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // BDataType + auto b_block_tile = make_static_distributed_tensor( + Policy::template MakeBRegTileDistribution()); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + constexpr auto idx1_js = tile_distributed_index<0>{}; + constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + auto b_scale_uint = 
type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + block_sync_lds(); + + // LDS write 0 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto 
i_j_idx_scale = make_tuple(idx0, idx1_js); + + auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + 
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + + auto b_scale_uint = + type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + // b_block_stride +=1; + } while(i < (num_loop - 1)); + } + // tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile); + // tail + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + { + // Leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + else + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if 
constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + } + __builtin_amdgcn_sched_barrier(0); + return c_block_tile; + } + }; + + /** + * @brief This function runs the pipeline using compile-time known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. + */ + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + void* p_smem, + index_t n = 0) const + { + ck_tile::ignore = n; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDqDataType& b) { return b; }, + bq_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 39a7c66f38..fe5d2bd7e1 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -131,8 +131,10 @@ class TestCkTileGemmQuantBase : 
public ::testing::Test const ck_tile::index_t kbatch, const float max_accumulated_value) { - using ComputeType = - std::conditional_t; + using ComputeType = std::conditional_t< + std::is_same_v, + ADataType_, + std::conditional_t>; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp index ef0d41909b..ec123364cb 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp @@ -16,9 +16,12 @@ using FP8 = ck_tile::fp8_t; using BF8 = ck_tile::bf8_t; using Half = ck_tile::half_t; using PkInt4 = ck_tile::pk_int4_t; +using BF16 = ck_tile::bf16_t; +using UInt8 = ck_tile::pk_fp4_raw_t; using BQuantGrouped = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; using GroupSize64 = ck_tile::QuantGroupShape>; +using GroupSize32 = ck_tile::QuantGroupShape>; // 2d block sizes for BQuant using GroupSize2D8N = ck_tile::QuantGroupShape>; @@ -42,6 +45,9 @@ using BQuantTypes = ::testing::Types< std::tuple, std::tuple, std::tuple, + std::tuple, + + std::tuple, // 2d cases with grouping also on the n axis std::tuple, diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index bf9c7a138d..4f2edb3609 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -60,6 +60,13 @@ struct GemmConfigPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Tile = 128; }; +struct GemmConfigMxFp4 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; +}; + struct GemmConfigPreshuffleQuant : public GemmConfigBase { static constexpr bool 
PreshuffleQuant = true; @@ -403,7 +410,8 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase ? (K / 2) : K; const ck_tile::index_t stride_C = N; // BQuant uses block/grouped quantization for B matrix @@ -414,15 +422,27 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + std::is_same_v ? K / 2 : K, + N, + stride_B, + this->is_row_major(BLayout{}))); ck_tile::HostTensor bq_bqk_bqn( ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{}))); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); - ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); - ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f}(bq_bqk_bqn); + } + else + { + ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); + } + // Allocate device memory ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); @@ -501,13 +521,22 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + if constexpr(std::is_same_v) + ck_tile::reference_mxfp4gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + else + ck_tile::reference_gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); // Get device result ck_tile::HostTensor c_m_n_dev_result( @@ -580,33 +609,37 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase; - using GemmPipeline = - std::conditional_t, - ck_tile::WPQuantBPipelineAgBgCrV2>; + using 
GemmPipeline = std::conditional_t< + PreshuffleB == false, + std::conditional_t, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>, + ck_tile::WPQuantBPipelineAgBgCrV2>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - Base::M_Warp, - Base::N_Warp, - Base::M_Warp_Tile, - Base::N_Warp_Tile, - Base::K_Warp_Tile, - false, // transpose_c - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledMMAPermuteN>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue, + ADataType, + BDataType>, + ck_tile::tuple<>, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + false, // transpose_c + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + TiledMMAPermuteN>>; using Kernel = ck_tile::QuantGemmKernel Date: Thu, 11 Dec 2025 08:09:29 -0800 Subject: [PATCH 42/65] Fix compilation errors with latest clang22 version. 
(#3396) * remove target attributes from deduction guides * switch CK_TILE_HOST_DEVICE_EXTERN based on clang version --- include/ck_tile/core/config.hpp | 4 ++++ include/ck_tile/core/numeric/math.hpp | 21 +++++++------------ .../core/utility/unary_element_function.hpp | 3 +-- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 678a2fbfff..0e7d1def75 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -39,8 +39,12 @@ #define CK_TILE_DEVICE inline __device__ #define CK_TILE_HOST_DEVICE inline __host__ __device__ #define CK_TILE_DEVICE_EXTERN __device__ +#if __clang_major__ < 22 #define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__ #else +#define CK_TILE_HOST_DEVICE_EXTERN +#endif +#else #define CK_TILE_HOST inline #define CK_TILE_DEVICE inline #define CK_TILE_HOST_DEVICE inline diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 57f3953514..8a0e3b3408 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -41,9 +41,8 @@ struct scales Scale lhs_; }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more template -__host__ __device__ scales(Scale) -> scales; +CK_TILE_HOST_DEVICE_EXTERN scales(Scale) -> scales; template struct plus @@ -66,8 +65,7 @@ struct plus } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ plus() -> plus; +CK_TILE_HOST_DEVICE_EXTERN plus() -> plus; template struct minus @@ -90,8 +88,7 @@ struct minus } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ minus() -> minus; +CK_TILE_HOST_DEVICE_EXTERN minus() -> minus; template struct multiplies @@ -114,8 +111,7 @@ struct multiplies } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ multiplies() -> multiplies; 
+CK_TILE_HOST_DEVICE_EXTERN multiplies() -> multiplies; template struct maximize @@ -345,8 +341,7 @@ struct equal } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ equal() -> equal; +CK_TILE_HOST_DEVICE_EXTERN equal() -> equal; template <> struct equal @@ -387,8 +382,7 @@ struct less } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ less() -> less; +CK_TILE_HOST_DEVICE_EXTERN less() -> less; template struct less_equal @@ -411,8 +405,7 @@ struct less_equal } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ less_equal() -> less_equal; +CK_TILE_HOST_DEVICE_EXTERN less_equal() -> less_equal; template <> struct less_equal diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp index b195275bdc..595b8522da 100644 --- a/include/ck_tile/core/utility/unary_element_function.hpp +++ b/include/ck_tile/core/utility/unary_element_function.hpp @@ -47,9 +47,8 @@ struct composes F f_; }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more template -__host__ __device__ composes(Ts&&...) -> composes...>; +CK_TILE_HOST_DEVICE_EXTERN composes(Ts&&...) 
-> composes...>; template struct saturates From 4dcc3e59c1c0195dae7ee9da9ab76d18a4cafe9f Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 11 Dec 2025 20:25:29 +0400 Subject: [PATCH 43/65] chore: update copyright header for misc files (#3402) * chore: update copyright header for misc files * fix: typo in kernel resulting in ci failure --- docs/conceptual/ck_tile/convert_mermaid_to_svg.py | 3 +++ docs/conceptual/ck_tile/convert_raw_html_to_commented.py | 3 +++ docs/conceptual/ck_tile/update_diagrams.py | 3 +++ example/test_old_ck_gpu_reference.cpp | 2 +- experimental/builder/test/test_ckb_conv_builder.cpp | 2 ++ include/ck_tile/ref/conv_common.hpp | 2 +- include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp | 2 +- include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp | 2 +- include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp | 2 +- .../device_grouped_gemm_wmma_splitk_instance.hpp | 2 +- test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp | 2 +- test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp | 2 +- .../test_gemm_quant_bquant_preshuffle.cpp | 2 +- test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp | 2 +- test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp | 2 +- test/ck_tile/utility/test_fill.cpp | 2 +- test/ck_tile/warp_gemm/CMakeLists.txt | 3 +++ test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp | 2 +- .../practice_gemm_host_pipeline_agmem_bgmem_creg.hpp | 8 ++++---- 19 files changed, 31 insertions(+), 17 deletions(-) diff --git a/docs/conceptual/ck_tile/convert_mermaid_to_svg.py b/docs/conceptual/ck_tile/convert_mermaid_to_svg.py index 1d62405e53..2bfaffdb57 100644 --- a/docs/conceptual/ck_tile/convert_mermaid_to_svg.py +++ b/docs/conceptual/ck_tile/convert_mermaid_to_svg.py @@ -1,4 +1,7 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + """ Script to convert all mermaid diagrams in CK Tile docs to SVGs. 
This script: diff --git a/docs/conceptual/ck_tile/convert_raw_html_to_commented.py b/docs/conceptual/ck_tile/convert_raw_html_to_commented.py index e90bf9def0..8e4a849e7f 100644 --- a/docs/conceptual/ck_tile/convert_raw_html_to_commented.py +++ b/docs/conceptual/ck_tile/convert_raw_html_to_commented.py @@ -1,4 +1,7 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + """Convert raw HTML mermaid blocks to commented format for SVG conversion.""" import os diff --git a/docs/conceptual/ck_tile/update_diagrams.py b/docs/conceptual/ck_tile/update_diagrams.py index 2fbe2ef5a9..f78599010e 100644 --- a/docs/conceptual/ck_tile/update_diagrams.py +++ b/docs/conceptual/ck_tile/update_diagrams.py @@ -1,4 +1,7 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + """ Helper script to update SVG diagrams from commented mermaid sources in RST files. diff --git a/example/test_old_ck_gpu_reference.cpp b/example/test_old_ck_gpu_reference.cpp index 0bcf43d20b..9f12eaea4d 100644 --- a/example/test_old_ck_gpu_reference.cpp +++ b/example/test_old_ck_gpu_reference.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. // Standalone test program for Old CK GPU references // Tests naive_conv_fwd (existing) and future backward ops diff --git a/experimental/builder/test/test_ckb_conv_builder.cpp b/experimental/builder/test/test_ckb_conv_builder.cpp index e69de29bb2..81e63887c1 100644 --- a/experimental/builder/test/test_ckb_conv_builder.cpp +++ b/experimental/builder/test/test_ckb_conv_builder.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT diff --git a/include/ck_tile/ref/conv_common.hpp b/include/ck_tile/ref/conv_common.hpp index ed43e87b14..50ae18eb99 100644 --- a/include/ck_tile/ref/conv_common.hpp +++ b/include/ck_tile/ref/conv_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp b/include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp index a5f6a697f2..f75bdda912 100644 --- a/include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp +++ b/include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp b/include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp index 2ac9c19892..0839074dd4 100644 --- a/include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp +++ b/include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp b/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp index 720fa40297..f582fcd71a 100644 --- a/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp +++ b/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp index 6d5da9208b..d0de1c859b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp index 9ba0b9c804..b6e69cd649 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp index ec123364cb..4b1ad068a7 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp index 3a62fc091a..ae01bddf96 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp index 5a58ed886a..bb0fa21899 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp index 0fa4048dab..8b4c90f8b9 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" diff --git a/test/ck_tile/utility/test_fill.cpp b/test/ck_tile/utility/test_fill.cpp index 18f42c4ad0..3633f8bbff 100644 --- a/test/ck_tile/utility/test_fill.cpp +++ b/test/ck_tile/utility/test_fill.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/host/fill.hpp" #include "ck_tile/host/joinable_thread.hpp" diff --git a/test/ck_tile/warp_gemm/CMakeLists.txt b/test/ck_tile/warp_gemm/CMakeLists.txt index 664ebc003b..5079741e1b 100644 --- a/test/ck_tile/warp_gemm/CMakeLists.txt +++ b/test/ck_tile/warp_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx95") add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp) endif() diff --git a/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp b/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp index 7878fda618..47fa1ff43e 100644 --- a/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp +++ b/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include "ck_tile/host.hpp" diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp index 15c1743a86..45f439e8fa 100644 --- a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp @@ -28,9 +28,9 @@ struct PracticeGemmHostPipeline { // Size of the entire problem - const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K - const auto N = c_dram_ref.get_tensor_descriptor().get_length(number<1>{}); // M x N - const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K + const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K + const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N + const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K // Size of the block tile const auto MPerBlock = BlockTile::at(number<0>{}); @@ -83,7 +83,7 @@ struct PracticeGemmHostPipeline __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; const auto c_block_tile = block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); - auto c_window = make_tile_window(c_dram_ref, + auto c_window = make_tile_window(c_dram, make_tuple(number{}, number{}), {tile_origin_m, tile_origin_n}); store_tile(c_window, c_block_tile); From 45c4ea510c76366de9d1c102bd05bc02cf7eecc8 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 11 Dec 2025 22:34:15 +0400 Subject: [PATCH 44/65] chore: add copyright to pass the CI (#3407) --- .../38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp | 2 +- .../gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp | 2 +- .../gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp | 2 +- 
.../ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp index a022ce18e1..31d263ea1d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. #include "run_gemm_quant_example.inc" diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp index 58019d703e..95122630ee 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp index 6ce2ff10fa..7a2d1db2c8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp index c113521d6b..b63a312489 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once From ff194a427129beabd419904ee173c221bcc2a5e5 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 11 Dec 2025 22:39:20 +0400 Subject: [PATCH 45/65] build: Hot fix to reduce massive build time by just disabling the instances (#3408) Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- test/ck_tile/grouped_gemm_quant/CMakeLists.txt | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 7a7ae77730..892e123d3d 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -7,17 +7,17 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") - # Split into three separate test executables for faster parallel compilation - add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + # Split into three separate test executables for faster parallel compilation + add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor 
test_grouped_gemm_quant_tensor.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +# add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) +# target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() From 4011dbfec31a711aaa4c1071c31bdc55f9b7974a Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Thu, 11 Dec 2025 14:23:43 -0800 Subject: [PATCH 46/65] [CK-Tile] fixup codegen for tile engine ops gemm multid and gemm preshuffle (#3383) * fixup gemm multi-d and preshuffle in tile engine codegen --------- Co-authored-by: Thrupti Raj Lakshmana Gowda --- .../gemm_multi_d_instance_builder.py | 64 ++++++----------- .../gemm_preshuffle_instance_builder.py | 68 +++++++------------ 2 files changed, 44 insertions(+), 88 deletions(-) diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py index 06da7ea8a2..f04c2a2c96 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ 
b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -452,34 +452,23 @@ struct SelectedKernel {{ using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}; static float launch(const ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) {{ - const ck_tile::index_t k_grain = args.k_batch * TileK; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + constexpr auto scheduler = {scheduler_type_map.get(scheduler)}; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + ck_tile::TileGemmUniversalTraits, + scheduler>; - float ave_time{{0}}; + using GemmPipeline = {pipeline_impl_map.get(pipeline)}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = {scheduler_type_map.get(scheduler)}; + const auto Run = [&](const auto memory_operation_) {{ [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; - - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - ADataType, - BDataType, - AccDataType, - TileShape, - ck_tile::TileGemmUniversalTraits, - scheduler, - has_hot_loop_v, - tail_number_v>; - using GemmPipeline = {pipeline_impl_map.get(pipeline)}; - - // Epilogue + // Epilogue """ # Add epilogue configuration based on type @@ -552,29 +541,18 @@ struct SelectedKernel {{ // Launch kernel constexpr int kBlockPerCu = {k_block_per_cu}; - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( stream, ck_tile::make_kernel(GemmKernelMultiD{{}}, grids, blocks, 0, kargs)); - - return 
ave_time; }}; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{ - if(args.k_batch == 1) {{ - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{{}}); - }} else {{ - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{{}}); - }} - }}; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + if(args.k_batch == 1) {{ + return Run(ck_tile::integral_constant{{}}); + }} else {{ + return Run(ck_tile::integral_constant{{}}); + }} }} }}; """ diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index 654a039b9c..62c239590a 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -484,35 +484,24 @@ struct SelectedKernel {{ using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2")}; static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ - const ck_tile::index_t k_grain = args.k_batch * TileK; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{{0}}; + constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Default")}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Default")}; + using UniversalGemmProblem = 
ck_tile::UniversalGemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + ck_tile::TileGemmUniversalTraits, + scheduler>; + + using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2")}; + + const auto Run = [&](const auto memory_operation_) {{ [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - ADataType, - BDataType, - AccDataType, - TileShape, - ck_tile::TileGemmUniversalTraits, - scheduler, - has_hot_loop_v, - tail_number_v>; - - using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2")}; - // Epilogue """ @@ -590,29 +579,18 @@ struct SelectedKernel {{ // Launch kernel constexpr int kBlockPerCu = {k_block_per_cu}; - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( stream, ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); - - return ave_time; }}; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{ - if(args.k_batch == 1) {{ - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{{}}); - }} else {{ - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{{}}); - }} - }}; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + if(args.k_batch == 1) {{ + return Run(ck_tile::integral_constant{{}}); + }} else {{ + return Run(ck_tile::integral_constant{{}}); + }} }} }}; """ From 8d7a4e0c73e1d2741fecea200f14bda1dcacc8f7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:09:40 -0800 Subject: [PATCH 47/65] Bump rocm-docs-core[api_reference] from 1.31.0 to 1.31.1 in /docs/sphinx (#3410) Bumps [rocm-docs-core[api_reference]](https://github.com/ROCm/rocm-docs-core) from 1.31.0 to 1.31.1. 
- [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.31.0...v1.31.1) --- updated-dependencies: - dependency-name: rocm-docs-core[api_reference] dependency-version: 1.31.1 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index b607daa9ff..b1ab09e6f7 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core[api_reference]==1.31.0 +rocm-docs-core[api_reference]==1.31.1 sphinxcontrib-bibtex==2.6.5 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index fce859cf0e..099e9e439f 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -237,7 +237,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core[api-reference]==1.31.0 +rocm-docs-core[api-reference]==1.31.1 # via -r requirements.in rpds-py==0.24.0 # via From 98696413248802ab8007b709e5fc76988b5600b6 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 12 Dec 2025 09:27:12 -0800 Subject: [PATCH 48/65] disable test_tile_gemm_quant_bquant_preshuffle (#3420) --- test/ck_tile/gemm_block_scale/CMakeLists.txt | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 8309b14f0a..2b0ffaafa2 100755 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -24,10 +24,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") 
target_compile_options(test_tile_gemm_quant_bquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) # BQuant tests (with PreshuffleB) - add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle - test_gemm_quant_bquant_preshuffle.cpp - ) - target_compile_options(test_tile_gemm_quant_bquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # disabling this test until it can be built within reasonable time! + # currently taking ~50 minutes on gfx12! + #add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle + # test_gemm_quant_bquant_preshuffle.cpp + #) + #target_compile_options(test_tile_gemm_quant_bquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) # RowColQuant tests add_gtest_executable(test_tile_gemm_quant_rowcol From fc7bf0ab1c5ed28e5962681007f84a2e8d3ee051 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Sat, 13 Dec 2025 01:28:37 +0800 Subject: [PATCH 49/65] [CK_TILE] Port hw independent changes from internal repo to develop branch (#3301) * [CK_TILE] Port hw independent changes from internal repo to develop branch It includes PR#96, #114, #120, #121. 
* correct rebase error --- example/ck_tile/03_gemm/gemm_utils.hpp | 2 +- example/ck_tile/03_gemm/run_gemm_example.inc | 4 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 2 + .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 1 + .../ops/gemm/kernel/streamk_gemm_kernel.hpp | 7 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 2 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 32 +- .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 25 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 5 +- .../ops/reduce/block/block_reduce2d.hpp | 2 +- include/ck_tile/utility/json_dump.hpp | 475 +++++++++--------- .../epilogue/test_cshuffle_epilogue_util.hpp | 2 +- test/ck_tile/gemm_multi_abd/CMakeLists.txt | 2 +- .../test_gemm_multi_abd_cshuffle.cpp | 15 +- .../test_gemm_multi_abd_default2d.cpp | 8 +- .../test_gemm_multi_abd_util.hpp | 36 +- .../test_gemm_pipeline_util.hpp | 24 +- .../grouped_gemm_multi_d/CMakeLists.txt | 2 +- .../test_grouped_gemm_multi_d.cpp | 53 +- .../grouped_gemm_preshuffle/CMakeLists.txt | 2 +- .../test_grouped_gemm_preshuffle.cpp | 12 +- .../test_grouped_gemm_preshuffle_util.hpp | 62 ++- 22 files changed, 465 insertions(+), 310 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index b25aec101b..47c47334e7 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -459,7 +459,7 @@ struct PipelineTypeTraits ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; }; -auto create_args() +inline auto create_args() { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3840", "m dimension") diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index c4f100b36b..78f3a9b0b3 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -197,8 +197,8 @@ bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, return pass; } -std::tuple -parse_gemm_size(ck_tile::ArgParser& arg_parser) 
+std::tuple inline parse_gemm_size( + ck_tile::ArgParser& arg_parser) { ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 3445f063f5..52b2b86574 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -986,6 +986,8 @@ struct MoeSortingKernel p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id; } } + __syncthreads(); + smem_cumdup(num_experts) = smem_cumsum(num_experts); // fill the p_sorted_token_ids/p_sorted_weights diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 63993c5eb6..838fc236d2 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -561,6 +561,7 @@ struct GroupedGemmKernel const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d); Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d); + block_sync_lds(); block_id = block_id + grid_size; // advance to next block // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR if(block_id >= cum_grid_size) diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index 91f1358321..6130107cfe 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -631,6 +631,7 @@ struct StreamKKernel tile_idx += kargs.tile_partitioner.get_grid()) { BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0); + block_sync_lds(); } // Stream-K section @@ -679,8 +680,8 @@ struct StreamKKernel { hipDeviceProp_t dev_prop; hipDevice_t dev; - 
hip_check_error(hipGetDevice(&dev)); - hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + ck_tile::hip_check_error(hipGetDevice(&dev)); + ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); int num_cu = dev_prop.multiProcessorCount; return num_cu; @@ -700,7 +701,7 @@ struct StreamKKernel constexpr int min_block_per_cu = 1; const auto kernel = kentry; - hip_check_error( + ck_tile::hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); return max(occupancy, 1); diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 4b28ac3f12..866a4cc693 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -280,7 +280,7 @@ struct UniversalGemmKernel using Kernel = UniversalGemmKernel; const auto kernel = kentry<1, Kernel, KernelArgs>; int occupancy; - hip_check_error( + ck_tile::hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0)); const int grid_size = get_available_compute_units(s) * occupancy; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 16ed8de22f..936c38ddf3 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -9,11 +9,35 @@ namespace ck_tile { +template +struct BaseGemmPipelineAGmemBGmemCRegV1 +{ + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr bool UsePersistentKernel = false; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; } + + 
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t) + { + return TailNumber::Empty; + } + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) + { + return run_func(bool_constant{}, integral_constant{}); + } +}; + // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register template -struct GemmPipelineAGmemBGmemCRegV1 +struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 { using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -48,14 +72,14 @@ struct GemmPipelineAGmemBGmemCRegV1 template static constexpr index_t GetVectorSizeA() { - return Problem::VectorSizeA; + return Policy::template GetVectorSizeA(); } template static constexpr index_t GetVectorSizeB() { - return Problem::VectorSizeB; + return Policy::template GetVectorSizeB(); } - static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index 5dbcde80a6..c711c768ec 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -9,11 +9,34 @@ namespace ck_tile { +template +struct BaseGemmPipelineAGmemBGmemCRegV2 +{ + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + 
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; } + + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t) + { + return TailNumber::Empty; + } + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) + { + return run_func(bool_constant{}, integral_constant{}); + } +}; // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register template -struct GemmPipelineAGmemBGmemCRegV2 +struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2 { using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index d76fd6dc0f..47607a40f5 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -43,13 +43,14 @@ template + bool Preshuffle_ = false, + int VectorSize_ = 16> struct TileGemmUniversalTraits { static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; static constexpr bool kPadK = kPadK_; - static constexpr int _VectorSize = 16; + static constexpr int _VectorSize = VectorSize_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; using AsLayout = AsLayout_; diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index cbf4afefb2..ba6ed27651 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -425,7 +425,7 @@ struct BlockReduce2dCrossWarpSync if constexpr(num_reduce_warps == 1) return; - + block_sync_lds(); // Each warp's lane 0 writes its partial results to shared memory const index_t smem_offset = warp_id; if(lane_id == 0) diff --git a/include/ck_tile/utility/json_dump.hpp b/include/ck_tile/utility/json_dump.hpp index b5bab28cac..03e97c0b76 100644 
--- a/include/ck_tile/utility/json_dump.hpp +++ b/include/ck_tile/utility/json_dump.hpp @@ -160,23 +160,23 @@ void dump_gemm_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_batched_gemm_json_results(const std::string& json_filename, - const std::string& op_name, - int M, - int N, - int K, - int stride_A, - int stride_B, - int stride_C, - int batch_stride_A, - int batch_stride_B, - int batch_stride_C, - int batch_count, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "batched_gemm_basic") +inline void dump_batched_gemm_json_results(const std::string& json_filename, + const std::string& op_name, + int M, + int N, + int K, + int stride_A, + int stride_B, + int stride_C, + int batch_stride_A, + int batch_stride_B, + int batch_stride_C, + int batch_count, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "batched_gemm_basic") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -218,20 +218,20 @@ void dump_grouped_gemm_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_flatmm_json_results(const std::string& json_filename, - const std::string& datatype, - int M, - int N, - int K, - int stride_A, - int stride_B, - int stride_C, - int kbatch, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "flatmm_basic") +inline void dump_flatmm_json_results(const std::string& json_filename, + const std::string& datatype, + int M, + int N, + int K, + int stride_A, + int stride_B, + int stride_C, + int kbatch, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "flatmm_basic") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -248,21 +248,22 @@ void dump_flatmm_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void 
dump_gemm_multi_d_fp16_json_results(const std::string& json_filename, - const std::string& op_name, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideD0, - int StrideD1, - int StrideE, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "gemm_multi_d_fp16") +inline void +dump_gemm_multi_d_fp16_json_results(const std::string& json_filename, + const std::string& op_name, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "gemm_multi_d_fp16") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -280,14 +281,14 @@ void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_elementwise_json_results(const std::string& json_filename, - const std::string& prec, - int grid_size, - int block_size, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "elementwise") +inline void dump_elementwise_json_results(const std::string& json_filename, + const std::string& prec, + int grid_size, + int block_size, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "elementwise") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -298,22 +299,22 @@ void dump_elementwise_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_layernorm2d_fwd_json_results(const std::string& json_filename, - const std::string& prec_i, - const std::string& prec_o, - const std::string& prec_sm, - const std::string& prec_sy, - int m, - int n, - int x_stride, - int xr_stride, - int y_stride, - int yr_stride, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "layernorm2d_fwd") +inline void dump_layernorm2d_fwd_json_results(const std::string& 
json_filename, + const std::string& prec_i, + const std::string& prec_o, + const std::string& prec_sm, + const std::string& prec_sy, + int m, + int n, + int x_stride, + int xr_stride, + int y_stride, + int yr_stride, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "layernorm2d_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -357,13 +358,13 @@ void dump_reduce_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_permute_json_results(const std::string& json_filename, - const std::string& data_type, - bool pass, - float ave_time, - float tflop, - float gb_per_sec, - const std::string& kernel_name = "permute") +inline void dump_permute_json_results(const std::string& json_filename, + const std::string& data_type, + bool pass, + float ave_time, + float tflop, + float gb_per_sec, + const std::string& kernel_name = "permute") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -373,19 +374,19 @@ void dump_permute_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_topk_softmax_json(const std::string& json_filename, - const std::string& input_prec, - const std::string& weight_prec, - int tokens, - int experts, - int topk, - int stride_input, - int stride_output, - float ave_time, - float tflop, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "topk_softmax") +inline void dump_topk_softmax_json(const std::string& json_filename, + const std::string& input_prec, + const std::string& weight_prec, + int tokens, + int experts, + int topk, + int stride_input, + int stride_output, + float ave_time, + float tflop, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "topk_softmax") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -401,20 +402,20 @@ void dump_topk_softmax_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void 
dump_rmsnorm2d_fwd_json(const std::string& json_filename, - const std::string& prec_str, - int m, - int n, - int x_stride, - int xr_stride, - int y_stride, - int yr_stride, - int use_model_sensitive_rmsnorm, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "rmsnorm2d_fwd") +inline void dump_rmsnorm2d_fwd_json(const std::string& json_filename, + const std::string& prec_str, + int m, + int n, + int x_stride, + int xr_stride, + int y_stride, + int yr_stride, + int use_model_sensitive_rmsnorm, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "rmsnorm2d_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -431,19 +432,19 @@ void dump_rmsnorm2d_fwd_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_add_rmsnorm2d_rdquant_fwd_json( - const std::string& json_filename, - const std::string& input_data_type, - const std::string& quantized_data_type, - int m, - int n, - int stride, - float epsilon, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd") +inline void +dump_add_rmsnorm2d_rdquant_fwd_json(const std::string& json_filename, + const std::string& input_data_type, + const std::string& quantized_data_type, + int m, + int n, + int stride, + float epsilon, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -458,17 +459,17 @@ void dump_add_rmsnorm2d_rdquant_fwd_json( END_JSON_DUMP_FILE(); } -void dump_smoothquant_json(const std::string& json_filename, - const std::string& prec_str, - int m, - int n, - int x_stride, - int y_stride, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "smoothquant") +inline void dump_smoothquant_json(const 
std::string& json_filename, + const std::string& prec_str, + int m, + int n, + int x_stride, + int y_stride, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "smoothquant") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -482,19 +483,19 @@ void dump_smoothquant_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_moe_sorting_json(const std::string& json_filename, - const std::string& index_prec, - const std::string& weight_prec, - const std::string& workspace_size, - int dispatch_policy, - int tokens, - int num_experts, - int topk, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "moe_sorting") +inline void dump_moe_sorting_json(const std::string& json_filename, + const std::string& index_prec, + const std::string& weight_prec, + const std::string& workspace_size, + int dispatch_policy, + int tokens, + int num_experts, + int topk, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "moe_sorting") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -510,19 +511,19 @@ void dump_moe_sorting_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_batched_transpose_json(const std::string& json_filename, - int N, - int C, - int H, - int W, - const std::string& layout_in, - const std::string& layout_out, - const std::string& prec, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "batched_transpose") +inline void dump_batched_transpose_json(const std::string& json_filename, + int N, + int C, + int H, + int W, + const std::string& layout_in, + const std::string& layout_out, + const std::string& prec, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "batched_transpose") { START_JSON_DUMP_FILE(json_filename); 
ADD_KEY_VALUE("name", kernel_name); @@ -538,19 +539,19 @@ void dump_batched_transpose_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_moe_smoothquant_json(const std::string& json_filename, - const std::string& prec_i, - const std::string& prec_o, - int tokens, - int hidden_size, - int stride, - int experts, - int topk, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "moe_smoothquant") +inline void dump_moe_smoothquant_json(const std::string& json_filename, + const std::string& prec_i, + const std::string& prec_o, + int tokens, + int hidden_size, + int stride, + int experts, + int topk, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "moe_smoothquant") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -566,26 +567,26 @@ void dump_moe_smoothquant_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_fused_moe_json(const std::string& json_filename, - const std::string& api_str, - const std::string& prec_str, - int tokens, - bool is_local_token, - int local_tokens, - int experts, - int topk, - int hidden_size, - int intermediate_size, - int stride, - int block_m, - int activation, - bool gate_only, - bool fused_quant, - bool pass, - float ave_time, - float tflops, - float tb_per_sec, - const std::string& kernel_name = "fused_moe") +inline void dump_fused_moe_json(const std::string& json_filename, + const std::string& api_str, + const std::string& prec_str, + int tokens, + bool is_local_token, + int local_tokens, + int experts, + int topk, + int hidden_size, + int intermediate_size, + int stride, + int block_m, + int activation, + bool gate_only, + bool fused_quant, + bool pass, + float ave_time, + float tflops, + float tb_per_sec, + const std::string& kernel_name = "fused_moe") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -610,29 +611,29 @@ void 
dump_fused_moe_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_fmha_fwd_json_results(const std::string& json_filename, - const std::string& prec, - const std::string& mode, - const std::string& io_layout, - int batch, - int nhead, - int nhead_k, - int seqlen_qs, - int seqlen_ks, - int seqlen_kpads, - int hdim_q, - int hdim_v, - float scale_s, - float p_drop, - bool lse, - const std::string& qscale, - const std::string& bias, - const std::string& vlayout, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "fmha_fwd") +inline void dump_fmha_fwd_json_results(const std::string& json_filename, + const std::string& prec, + const std::string& mode, + const std::string& io_layout, + int batch, + int nhead, + int nhead_k, + int seqlen_qs, + int seqlen_ks, + int seqlen_kpads, + int hdim_q, + int hdim_v, + float scale_s, + float p_drop, + bool lse, + const std::string& qscale, + const std::string& bias, + const std::string& vlayout, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "fmha_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -658,33 +659,33 @@ void dump_fmha_fwd_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_fmha_bwd_json_results(const std::string& json_filename, - const std::string& data_type, - const std::string& mode, - const std::string& i_perm, - const std::string& o_perm, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - float scale, - const std::string& bias, - bool use_dbias, - float p_drop, - bool s_randval, - bool deterministic, - const std::string& mask, - int mask_left, - int mask_right, - int workspace_size, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "fmha_bwd") +inline void dump_fmha_bwd_json_results(const std::string& json_filename, + const 
std::string& data_type, + const std::string& mode, + const std::string& i_perm, + const std::string& o_perm, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + const std::string& bias, + bool use_dbias, + float p_drop, + bool s_randval, + bool deterministic, + const std::string& mask, + int mask_left, + int mask_right, + int workspace_size, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "fmha_bwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp index 4fdbf23864..9b90110c07 100644 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp @@ -130,7 +130,7 @@ auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None) constexpr index_t kMPerBlock = Problem::kMPerBlock; constexpr index_t kNPerBlock = Problem::kNPerBlock; - constexpr index_t kBlockSize = Problem::kBlockSize; + index_t kBlockSize = ck_tile::is_wave32() ? 
Problem::kBlockSize / 2 : Problem::kBlockSize; std::cout << "Running CShuffleEpilogue test with M=" << M << ", N=" << N << ", MPerBlock=" << kMPerBlock << ", NPerBlock=" << kNPerBlock diff --git a/test/ck_tile/gemm_multi_abd/CMakeLists.txt b/test/ck_tile/gemm_multi_abd/CMakeLists.txt index 2dccf9cd60..03759652cd 100644 --- a/test/ck_tile/gemm_multi_abd/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_abd/CMakeLists.txt @@ -7,7 +7,7 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") add_gtest_executable(test_ck_tile_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) add_gtest_executable(test_ck_tile_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) target_compile_definitions(test_ck_tile_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp index 08997529b2..ab00f16632 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp @@ -20,20 +20,21 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using KernelTypes = ::testing::Types< // Has cshuffle epilogue enabled // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, 
F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, +#endif std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> - >; + std::tuple< Row, Row, Col, Col, Row, Row, Row, 
F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type> + >; // clang-format on TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes); diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp index dac33b4656..c4bfc3e7cb 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -20,17 +20,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using KernelTypes = ::testing::Types< // Has cshuffle epilogue disabled // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, +#endif std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, 
F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp index ee045c7f48..8cee050db2 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -23,6 +23,28 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 
128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +#endif +} + template & args, const ck_tile::stream_config& s) { - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; +#if CK_TILE_USE_WMMA + using ADataType = + ck_tile::remove_cvref_t{}, AsDataType>>; + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); +#else constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; +#endif constexpr bool DoubleSmemBuffer = false; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 43a73738d9..7c085b5098 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -13,6 +13,28 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 
128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +#endif +} + template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, @@ -80,7 +102,7 @@ struct config_wmma static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template diff --git a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt index f86da3c4d5..5363e365fc 100644 --- a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt @@ -9,7 +9,7 @@ endif() # Use standard asm for rtn bf16 conversion instead of turncate list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) -if(GPU_TARGETS MATCHES "gfx94|gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_multi_d test_grouped_gemm_multi_d.cpp) target_compile_options(test_ck_tile_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp index 65c662199b..8d56c274aa 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp @@ -29,9 +29,6 @@ template ; - static constexpr int M_Tile_ = M_Tile_val_; - static constexpr int N_Tile_ = N_Tile_val_; - static constexpr int K_Tile_ = K_Tile_val_; - static constexpr int M_Warp_ = M_Warp_val_; - static constexpr int N_Warp_ = N_Warp_val_; - static constexpr int K_Warp_ = K_Warp_val_; - static constexpr int M_Warp_Tile_ = M_Warp_Tile_val_; - static constexpr int N_Warp_Tile_ = N_Warp_Tile_val_; - static constexpr int K_Warp_Tile_ = K_Warp_Tile_val_; + static constexpr int 
M_Tile_ = M_Tile_val_; + static constexpr int N_Tile_ = N_Tile_val_; + static constexpr int K_Tile_ = K_Tile_val_; + static constexpr int M_Warp_ = M_Warp_val_; + static constexpr int N_Warp_ = N_Warp_val_; + static constexpr int K_Warp_ = K_Warp_val_; +#if CK_TILE_USE_WMMA + static constexpr int M_Warp_Tile_ = 16; + static constexpr int N_Warp_Tile_ = 16; + static constexpr int K_Warp_Tile_ = 16; +#else + static constexpr int M_Warp_Tile_ = 32; + static constexpr int N_Warp_Tile_ = 32; + static constexpr int K_Warp_Tile_ = (M_Warp_val_ == 2) ? 16 : 8; +#endif static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_; static constexpr auto Scheduler_ = Scheduler_val_; static constexpr PipelineType Pipeline_ = Pipeline_val_; @@ -68,21 +71,21 @@ struct KernelConfig // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent + // ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, DoubleSmemBuffer, Scheduler, Pipeline, Persistent // FP16 A/B/D/E - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 - KernelConfig< Row, Col, Row, F16, F16, 
F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4 // BF16 A/B/D/E - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 
64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4 >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt index 08b413aea9..3a230aed0c 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt @@ -6,7 +6,7 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94|gfx95") +if(GPU_TARGETS MATCHES 
"gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp) target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp index 623d0152d6..450b7b8f24 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp @@ -50,16 +50,16 @@ struct KernelConfig // clang-format off using KernelTypes = ::testing::Types< // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu - KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>, +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>, - KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>, KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>, - - KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>, - KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>, KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>, - +#endif + KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>, KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 128, 128, 128, 2>, diff --git 
a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index 0eb388082b..5628b6feae 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -14,6 +14,9 @@ template constexpr ck_tile::index_t get_k_warp_tile_flatmm() { +#if CK_TILE_USE_WMMA + return 16; +#else #if defined(CK_GFX950_SUPPORT) if constexpr(M_Warp_Tile == 32) return sizeof(PrecType) == 2 ? 16 : 64; @@ -25,6 +28,7 @@ constexpr ck_tile::index_t get_k_warp_tile_flatmm() else return sizeof(PrecType) == 2 ? 32 : 64; #endif +#endif } template @@ -101,13 +105,40 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view( - {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / N_Warp_Tile, + N_Warp_Tile, + k_ / K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = N_Warp_Tile == 32 ? 
2 : 4; + } + ck_tile::HostTensor t_view( + {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template @@ -115,6 +146,11 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test const ck_tile::stream_config& s, void* kargs_ptr) { + constexpr ck_tile::index_t WaveSize = 32; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile); + constexpr bool SupportVectorSize16 = + (M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0; + constexpr int VectorSize = SupportVectorSize16 ? 16 : 8; using GemmShape = ck_tile::TileGemmShape, @@ -137,7 +173,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test /*UseStructuredSparsity*/ false, /*Persistent*/ false, /*NumWaveGroups*/ 1, - /*Preshuffle*/ true>; + /*Preshuffle*/ true, + VectorSize>; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem, ck_tile::sequence, @@ -230,7 +273,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test /*UseStructuredSparsity*/ false, /*Persistent*/ true, // Enable persistent mode /*NumWaveGroups*/ 1, - /*Preshuffle*/ true>; + /*Preshuffle*/ true, + VectorSize>; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem Date: Fri, 12 Dec 2025 19:26:47 +0100 Subject: [PATCH 50/65] Fix compilation ab scale multi target (#3413) --- .../gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp index ac5b7dd0c4..0974f45a2b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp @@ -527,11 +527,6 @@ struct 
GridwiseGemm_wmma_cshuffle_v3_ab_scale } else { -#if defined(__gfx11__) - // TODO: remove this restriction - static_assert(ScaleBlockM >= MPerWmma, - "ScaleBlockM must be greater equal than MPerWmma"); -#endif static_assert( ScaleBlockK >= WmmaSelector:: From 9707ddb444f42b490c73b7884babccde2988ed7e Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Fri, 12 Dec 2025 17:08:26 -0700 Subject: [PATCH 51/65] [CK TILE GEMM STREAMK] update identifier names according to the new code style (#3348) * [CK TILE GEMM STREAMK] update identifier names according to the new code style --- .../ck_tile/40_streamk_gemm/gemm_utils.hpp | 56 +-- .../40_streamk_gemm/run_gemm_example.inc | 380 +++++++++--------- .../40_streamk_gemm/streamk_gemm_basic.cpp | 204 +++++----- 3 files changed, 328 insertions(+), 312 deletions(-) diff --git a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp index dad31ec637..34c6c6b0ae 100644 --- a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp +++ b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp @@ -7,46 +7,46 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -struct GemmConfigBase +struct GemmConfigurationBase { - static constexpr bool kPadM = true; - static constexpr bool kPadN = true; - static constexpr bool kPadK = true; + static constexpr bool PAD_M = true; + static constexpr bool PAD_N = true; + static constexpr bool PAD_K = true; - static constexpr bool PermuteA = false; - static constexpr bool PermuteB = false; + static constexpr bool PERMUTE_A = false; + static constexpr bool PERMUTE_B = false; - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; + static constexpr bool TRANSPOSE_C = false; + static constexpr bool USE_STRUCTURED_SPARSITY = false; - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t 
NumWaveGroups = 1; - static constexpr bool Preshuffle = false; - static constexpr bool DoubleSmemBuffer = false; + static constexpr int BLOCK_PER_CU = 1; + static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t NUM_WAVE_GROUPS = 1; + static constexpr bool PRESHUFFLE = false; + static constexpr bool DOUBLE_SMEM_BUFFER = false; }; -template -struct GemmConfigMemoryInterwave : public GemmConfigBase +template +struct GemmConfigurationMemoryInterwave : public GemmConfigurationBase { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 16; + static constexpr ck_tile::index_t M_TILE = 256; + static constexpr ck_tile::index_t N_TILE = 256; + static constexpr ck_tile::index_t K_TILE = 16; - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; + static constexpr ck_tile::index_t M_WARP = 2; + static constexpr ck_tile::index_t N_WARP = 2; + static constexpr ck_tile::index_t K_WARP = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + static constexpr ck_tile::index_t M_WARP_TILE = 32; + static constexpr ck_tile::index_t N_WARP_TILE = 32; + static constexpr ck_tile::index_t K_WARP_TILE = sizeof(PrecisionType) == 2 ? 
8 : 16; - static constexpr bool Persistent = Persistent_; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool PERSISTENT = IsPersistent; + static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave; }; template -struct StreamKGemmTypeConfig +struct StreamKGemmTypeConfiguration { using ADataType = ADataType_; using BDataType = BDataType_; @@ -54,7 +54,7 @@ struct StreamKGemmTypeConfig using CDataType = CDataType_; }; -auto create_args(int argc, char* argv[]) +auto createArgs(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "512", "m dimension") diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc index d18ac2e68a..7442bd33f2 100644 --- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc +++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc @@ -12,31 +12,35 @@ static constexpr inline auto is_row_major(Layout) } template -auto calculate_rtol_atol(const ck_tile::index_t K, - const ck_tile::index_t kbatch, - const float max_accumulated_value) +auto calculateRtolAtol(const ck_tile::index_t k_dim, + const ck_tile::index_t k_batch, + const float max_accumulated_value) { using ComputeType = std::conditional_t; // Calculate thresholds - const auto rtol = ck_tile::get_relative_threshold( - ck_tile::integer_divide_ceil(K, kbatch)); - const auto atol = ck_tile::get_absolute_threshold( - max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + const auto relative_tolerance = + ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(k_dim, k_batch)); + const auto absolute_tolerance = + ck_tile::get_absolute_threshold( + max_accumulated_value / k_batch, ck_tile::integer_divide_ceil(k_dim, k_batch)); // Calculate error due to multiple WGs working in the same C macro tile - const auto rtol_split_k = - ck_tile::get_relative_threshold(kbatch); - const auto atol_split_k = 
ck_tile::get_absolute_threshold( - max_accumulated_value, kbatch); + const auto relative_tolerance_split_k = + ck_tile::get_relative_threshold(k_batch); + const auto absolute_tolerance_split_k = + ck_tile::get_absolute_threshold(max_accumulated_value, + k_batch); // Use higher threshold - return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + return ck_tile::make_tuple(std::max(relative_tolerance, relative_tolerance_split_k), + std::max(absolute_tolerance, absolute_tolerance_split_k)); } -template std::tuple gemm(const ck_tile::StreamKHostArgs& args, - const ck_tile::stream_config& s); + const ck_tile::stream_config& stream_config); -template -std::tuple -invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C, - int n_warmup, - int n_repeat, - bool flush_cache, - ck_tile::StreamKReductionStrategy reduction_strategy) +std::tuple invokeGemm(ck_tile::DeviceMem& a_m_k_device_memory, + ck_tile::DeviceMem& b_k_n_device_memory, + ck_tile::DeviceMem& c_m_n_device_memory, + ck_tile::index_t m_dim, + ck_tile::index_t n_dim, + ck_tile::index_t k_dim, + ck_tile::index_t stride_a, + ck_tile::index_t stride_b, + ck_tile::index_t stride_c, + int warmup_iterations, + int repeat_iterations, + bool flush_cache, + ck_tile::StreamKReductionStrategy reduction_strategy) { - ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - c_m_n_dev_buf.GetDeviceBuffer(), - M, - N, - K, - stride_A, - stride_B, - stride_C}; + ck_tile::StreamKHostArgs args{a_m_k_device_memory.GetDeviceBuffer(), + b_k_n_device_memory.GetDeviceBuffer(), + c_m_n_device_memory.GetDeviceBuffer(), + m_dim, + n_dim, + k_dim, + stride_a, + stride_b, + stride_c}; - std::tuple ave_time_and_batch; + std::tuple average_time_and_batch; 
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) { - ave_time_and_batch = gemm( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache}); + average_time_and_batch = gemm( + args, + ck_tile::stream_config{ + nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache}); } else /*Reduction*/ { - ave_time_and_batch = gemm( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache}); + average_time_and_batch = gemm( + args, + ck_tile::stream_config{ + nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache}); } - return ave_time_and_batch; + return average_time_and_batch; } template -bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, - const ck_tile::HostTensor& c_m_n_ref, - const ck_tile::tuple& rtol_atol, - const char* variant) +bool doVerify(const ck_tile::HostTensor& c_m_n_device_result, + const ck_tile::HostTensor& c_m_n_reference, + const ck_tile::tuple& relative_absolute_tolerances, + const char* variant) { - bool pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_ref, + bool pass = ck_tile::check_err(c_m_n_device_result, + c_m_n_reference, "Error: Incorrect results!", - rtol_atol.at(ck_tile::number<0>{}), - rtol_atol.at(ck_tile::number<1>{})); + relative_absolute_tolerances.at(ck_tile::number<0>{}), + relative_absolute_tolerances.at(ck_tile::number<1>{})); - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) - << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "Relative error threshold: " + << relative_absolute_tolerances.at(ck_tile::number<0>{}) + << " Absolute error threshold: " + << relative_absolute_tolerances.at(ck_tile::number<1>{}) << std::endl; std::cout << "The " << variant << " verification result is:" << (pass ? 
"correct" : "fail") << std::endl; return pass; } -ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy) +ck_tile::StreamKReductionStrategy getReductionStrategyValue(const std::string& strategy) { if(strategy == "atomic") { @@ -156,172 +165,169 @@ ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string } } -template -int run_gemm_example_with_layouts(int argc, - char* argv[], - const ALayout a_layout = ALayout{}, - const BLayout b_layout = BLayout{}, - [[maybe_unused]] const CLayout c_layout = CLayout{}) +int runGemmExampleWithLayouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) { - auto [result, arg_parser] = create_args(argc, argv); + auto [result, arg_parser] = createArgs(argc, argv); if(!result) return -1; - static_assert(!GemmConfig::Preshuffle, "Not implemented"); - static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented"); - static_assert(!GemmConfig::PermuteA, "Not implemented"); - static_assert(!GemmConfig::PermuteB, "Not implemented"); + static_assert(!GemmConfiguration::PRESHUFFLE, "Not implemented"); + static_assert(!GemmConfiguration::USE_STRUCTURED_SPARSITY, "Not implemented"); + static_assert(!GemmConfiguration::PERMUTE_A, "Not implemented"); + static_assert(!GemmConfiguration::PERMUTE_B, "Not implemented"); - using ADataType = typename TypeConfig::ADataType; - using BDataType = typename TypeConfig::BDataType; - using AccDataType = typename TypeConfig::AccDataType; - using CDataType = typename TypeConfig::CDataType; + using ADataType = typename TypeConfiguration::ADataType; + using BDataType = typename TypeConfiguration::BDataType; + using AccumulatorDataType = typename TypeConfiguration::AccDataType; + using CDataType = typename TypeConfiguration::CDataType; - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - 
ck_tile::index_t K = arg_parser.get_int("k"); + ck_tile::index_t m_dim = arg_parser.get_int("m"); + ck_tile::index_t n_dim = arg_parser.get_int("n"); + ck_tile::index_t k_dim = arg_parser.get_int("k"); - ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t stride_a = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_b = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); + int warmup_iterations = arg_parser.get_int("warmup"); + int repeat_iterations = arg_parser.get_int("repeat"); ck_tile::index_t init_method = arg_parser.get_int("init"); bool flush_cache = arg_parser.get_bool("flush_cache"); - ck_tile::StreamKReductionStrategy reduction_strategy = - get_reduction_strategy_value(arg_parser.get_str("reduction_strategy")); + getReductionStrategyValue(arg_parser.get_str("reduction_strategy")); - stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); - stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + stride_a = ck_tile::get_default_stride(m_dim, k_dim, stride_a, is_row_major(a_layout)); + stride_b = ck_tile::get_default_stride(k_dim, n_dim, stride_b, is_row_major(b_layout)); + stride_c = ck_tile::get_default_stride(m_dim, n_dim, stride_c, is_row_major(CLayout{})); - ck_tile::HostTensor a_m_k( - ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); - ck_tile::HostTensor c_m_n_dev_result( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::HostTensor a_m_k_host( + 
ck_tile::host_tensor_descriptor(m_dim, k_dim, stride_a, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n_host( + ck_tile::host_tensor_descriptor(k_dim, n_dim, stride_b, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_device_result( + ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{}))); if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_host); } else if(init_method == 1) { - ck_tile::FillMonotonicSeq{}(a_m_k); - ck_tile::FillMonotonicSeq{}(b_k_n); + ck_tile::FillMonotonicSeq{}(a_m_k_host); + ck_tile::FillMonotonicSeq{}(b_k_n_host); } else if(init_method == 2) { - ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); - ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n_host); } else { - a_m_k.SetZero(); - b_k_n.SetZero(); + a_m_k_host.SetZero(); + b_k_n_host.SetZero(); } - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + ck_tile::DeviceMem a_m_k_device_memory(a_m_k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_device_memory(b_k_n_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_device_memory(c_m_n_device_result.get_element_space_size_in_bytes()); - a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); + a_m_k_device_memory.ToDevice(a_m_k_host.data()); + b_k_n_device_memory.ToDevice(b_k_n_host.data()); + c_m_n_device_memory.SetZero(); + c_m_n_device_result.SetZero(); + auto [average_time, num_wgs_per_tile] = invokeGemm, + 
AccumulatorDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_device_memory, + b_k_n_device_memory, + c_m_n_device_memory, + m_dim, + n_dim, + k_dim, + stride_a, + stride_b, + stride_c, + warmup_iterations, + repeat_iterations, + flush_cache, + reduction_strategy); - auto [ave_time, num_wgs_per_tile] = invoke_gemm, - AccDataType, - CDataType, - ALayout, - BLayout, - ck_tile::tuple<>, - CLayout>(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - n_warmup, - n_repeat, - flush_cache, - reduction_strategy); + c_m_n_device_memory.FromDevice(c_m_n_device_result.data()); - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_byte = - sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K - << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C + std::size_t flop = std::size_t(2) * m_dim * n_dim * k_dim; + std::size_t num_byte = sizeof(ADataType) * m_dim * k_dim + sizeof(BDataType) * n_dim * k_dim + + sizeof(CDataType) * m_dim * n_dim; + float tflops = static_cast(flop) / 1.E9 / average_time; + float gb_per_sec = num_byte / 1.E6 / average_time; + std::cout << "Run Gemm kernel with M=" << m_dim << " N=" << n_dim << " K=" << k_dim + << " StrideA=" << stride_a << " StrideB=" << stride_b << " StrideC=" << stride_c << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name << " C_Layout=" << CLayout::name << " A_Type=" << ck_tile::DataTypeTraits::name << " B_Type=" << ck_tile::DataTypeTraits::name << " C_Type=" << ck_tile::DataTypeTraits::name << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " " - << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << 
ave_time + << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << average_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; - bool pass = false; // Memory on host to store gpu reference result - ck_tile::HostTensor c_m_n_ref( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - c_m_n_ref.SetZero(); + ck_tile::HostTensor c_m_n_reference( + ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{}))); + c_m_n_reference.SetZero(); if(arg_parser.get_int("v") == 1) // Validate on the CPU { - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_ref); + ck_tile::reference_gemm( + a_m_k_host, b_k_n_host, c_m_n_reference); const float max_accumulated_value = - *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, num_wgs_per_tile, max_accumulated_value); - pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU"); + *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end()); + const auto relative_absolute_tolerances = + calculateRtolAtol( + k_dim, num_wgs_per_tile, max_accumulated_value); + pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "CPU"); } else if(arg_parser.get_int("v") == 2) // Validate on the GPU { // Memory on device to store gpu reference result - ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes()); - c_m_n_gpu_buf_ref.SetZero(); - - ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); - BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); - CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + ck_tile::DeviceMem c_m_n_gpu_buffer_reference( + c_m_n_reference.get_element_space_size_in_bytes()); + c_m_n_gpu_buffer_reference.SetZero(); + ADataType* d_A = static_cast(a_m_k_device_memory.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_device_memory.GetDeviceBuffer()); + CDataType* d_C 
= static_cast(c_m_n_gpu_buffer_reference.GetDeviceBuffer()); ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); - - c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data()); + CLayout>( + d_A, d_B, d_C, m_dim, n_dim, k_dim, stride_a, stride_b, stride_c); + c_m_n_gpu_buffer_reference.FromDevice(c_m_n_reference.data()); const float max_accumulated_value = - *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, num_wgs_per_tile, max_accumulated_value); - pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU"); + *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end()); + const auto relative_absolute_tolerances = + calculateRtolAtol( + k_dim, num_wgs_per_tile, max_accumulated_value); + pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "GPU"); } return pass; diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp index 83795fbf6a..d3ee9fe9c6 100644 --- a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -4,11 +4,11 @@ #include "gemm_utils.hpp" #include "ck_tile/ops/common.hpp" -template std::tuple gemm(const ck_tile::StreamKHostArgs& args, - const ck_tile::stream_config& s) + const ck_tile::stream_config& stream_config) { - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence, - GemmConfig::PermuteA, - GemmConfig::PermuteB>; + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + GemmConfiguration::PERMUTE_A, + GemmConfiguration::PERMUTE_B>; - using TilePartitioner = - ck_tile::StreamKTilePartitioner; + using TilePartitioner = ck_tile:: + StreamKTilePartitioner; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; - const auto 
Run = [&](const auto memory_operation) -> std::tuple { + const auto runKernel = [&](const auto memory_operation) -> std::tuple { // We create the GEMM pipeline without specifying has_hot_loop or tail_num. // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -61,39 +67,39 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, ck_tile::CShuffleEpilogueProblem>; + GemmConfiguration::NUM_WAVE_GROUPS>>; using Kernel = ck_tile::StreamKKernel; - auto kargs = Kernel::MakeKernelArgs(args); - const auto workspace_size = Kernel::GetWorkSpaceSize(kargs); + auto kernel_args = Kernel::MakeKernelArgs(args); + const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args); ck_tile::DeviceMem workspace_data(workspace_size); workspace_data.SetZero(); - kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); + kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer(); - dim3 grids = Kernel::GridSize(kargs.tile_partitioner); + dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner); dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) + if(!Kernel::IsSupportedArgument(kernel_args)) { throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); } - if(s.log_level_ > 0) + if(stream_config.log_level_ > 0) { std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' << "shape: " << GemmShape::GetName() << '\n' @@ -109,7 +115,7 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, { // Clear the output C tensor results after each repetition of the kernel hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); } else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) { @@ -120,45 +126,47 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, std::function preprocess = reset_data_buffers; - float ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float average_time = + ck_tile::launch_kernel_time_mask(stream_config, + preprocess, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kernel_args)); - ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile(); - return std::tuple{ave_time, num_wgs_per_tile}; + ck_tile::index_t num_wgs_per_tile = + kernel_args.tile_partitioner.estimate_num_wgs_per_tile(); + return std::tuple{average_time, num_wgs_per_tile}; }; if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy) { - return Run(ck_tile::integral_constant{}); + return runKernel(ck_tile::integral_constant{}); } else // We are using ck_tile::StreamKReductionStrategy::Reduction { - return Run(ck_tile::integral_constant{}); + return runKernel(ck_tile::integral_constant{}); } } #include "run_gemm_example.inc" -template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +template +int runGemmExamplePrecisionType(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = 
ck_tile::tensor_layout::gemm::ColumnMajor; if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return runGemmExampleWithLayouts( argc, argv, Row{}, Col{}, Row{}); } else @@ -169,72 +177,74 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a return 0; } -template