From 771f37e4aaa58d956b678f9e50debbae22662b71 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Fri, 5 Dec 2025 14:18:30 -0800 Subject: [PATCH] Add the gfx1011 support on CK Tile with the SGPR builtin reading protection (#3350) * Finish the fixes * add the gfx1010 support macro * Fix the compilation error [ROCm/composable_kernel commit: 86a84ae61122b8ed2d2e40e45f108a8fa23d3210] --- include/ck_tile/core/config.hpp | 7 ++++ .../core/tensor/tile_scatter_gather.hpp | 3 +- .../core/tensor/tile_window_linear.hpp | 3 +- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 37 +++++++++++++++---- 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index de97b46336..678a2fbfff 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -357,6 +357,12 @@ struct amdgcn_compiler_target_state #endif // __gfx950__ // GFX10 +#if defined(__gfx1010__) + static constexpr bool CK_TILE_ARCH_GFX1010 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1010 = false; +#endif + #if defined(__gfx1030__) static constexpr bool CK_TILE_ARCH_GFX1030 = true; #else @@ -493,6 +499,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... 
se amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \ diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 97a44f38e8..7a4da64c4a 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -533,7 +533,8 @@ struct tile_scatter_gather size_per_buf; const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(m0_init_value); // This should be wave independent + m0_set_with_memory( + amd_wave_read_first_lane(m0_init_value)); // This should be wave independent using Traits = load_store_traits; diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 815c1bf158..6c84122d01 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -517,7 +517,8 @@ struct tile_window_linear size_per_buf; const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(m0_init_value); // This should be wave independent + m0_set_with_memory( + amd_wave_read_first_lane(m0_init_value)); // This should be wave independent using vector_t = typename Base::Traits::vector_t; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index d83338fbb2..51f0f5f1b1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ 
b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -99,28 +99,49 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV template CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { + // Estimated number of VMEM vector loads for A per block: + // total A bytes / (threads per block * vector width) constexpr index_t Aload_inst = (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize; + // Estimated number of VMEM vector loads for B per block: + // total B bytes / (threads per block * vector width) constexpr index_t Bload_inst = (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize; + + // Estimated number of VMEM loads for B's quant data (e.g. scales / zp). + // First ceil-divide by quant group size (how many elements share one scale), + // then by vector width to get an approximate number of vector loads. constexpr index_t BQload_inst = ck_tile::integer_divide_ceil( ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType), QuantGroupSize::kK * QuantGroupSize::kK), VectorLoadSize); - constexpr index_t kLdsVec = 8; + + // ToDo: Hardcoded, need to change in future. 
How many instructions are emitted per iteration + constexpr index_t kLdsInstCycle = 8; + // Total VMEM load instructions (A + B + quant data) constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; - constexpr index_t ds_read_inst = kMPerBlock / kLdsVec; - constexpr index_t ds_write_inst = Aload_inst; - constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); - constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + // Approximate number of LDS reads per block + constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle; + // Approximate number of LDS writes per block + // (e.g., writing A from VMEM into LDS once per A load) + constexpr index_t ds_write_inst = Aload_inst; + // Number of MFMA instructions per wave for one block tile: + constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + // How often (in MFMA units) we should insert DS (LDS) operations. + constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + // How often (in MFMA units) we should insert VMEM buffer loads. + // buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load + // is assumed to cover at most 4 MFMA instructions. constexpr index_t buffer_load_rep = min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma - static_for<0, nloop, 1>{}([&](auto j_inst) { - ignore = j_inst; + static_for<0, nloop, 1>{}([&](auto) { static_for<0, mfma_inst, 1>{}([&](auto i_inst) { __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA + // Insert LDS read/write groups periodically based on ds_rep. + // The % pattern staggers READ and WRITE so they don't collapse + // into the same cycle in the model.
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0) { __builtin_amdgcn_sched_group_barrier( @@ -140,6 +161,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read } } + // Always mark some VALU work in the loop to reflect auxiliary scalar + // or vector ALU instructions that coexist with MFMA (Blockscale calculation). __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU }); });