Add gfx1011 support on CK Tile with SGPR builtin read protection (#3350)

* Finish the fixes

* Add the gfx1010 support macro

* Fix the compilation error
This commit is contained in:
Thomas Ning
2025-12-05 14:18:30 -08:00
committed by GitHub
parent 6b1bceca7b
commit 86a84ae611
4 changed files with 41 additions and 9 deletions

View File

@@ -357,6 +357,12 @@ struct amdgcn_compiler_target_state
#endif // __gfx950__
// GFX10
#if defined(__gfx1010__)
static constexpr bool CK_TILE_ARCH_GFX1010 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
#endif
#if defined(__gfx1030__)
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
#else
@@ -493,6 +499,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \

View File

@@ -533,7 +533,8 @@ struct tile_scatter_gather
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
m0_set_with_memory(
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
using Traits = load_store_traits;

View File

@@ -517,7 +517,8 @@ struct tile_window_linear
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
m0_set_with_memory(
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
using vector_t = typename Base::Traits::vector_t;

View File

@@ -99,28 +99,49 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
template <index_t nloop>
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
// Estimated number of VMEM vector loads for A per block:
// total A bytes / (threads per block * vector width)
constexpr index_t Aload_inst =
(kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize;
// Estimated number of VMEM vector loads for B per block:
// total B bytes / (threads per block * vector width)
constexpr index_t Bload_inst =
(kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize;
// Estimated number of VMEM loads for B's quant data (e.g. scales / zp).
// First ceil-divide by quant group size (how many elements share one scale),
// then by vector width to get an approximate number of vector loads.
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
QuantGroupSize::kK * QuantGroupSize::kK),
VectorLoadSize);
constexpr index_t kLdsVec = 8;
// ToDo: Hardcoded, need to change in future. How many instruction emit per iteration
constexpr index_t kLdsInstCycle = 8;
// Total VMEM load instructions (A + B + quant data)
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
constexpr index_t ds_read_inst = kMPerBlock / kLdsVec;
constexpr index_t ds_write_inst = Aload_inst;
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
// Approximate number of LDS reads per block
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
// Approximate number of LDS writes per block
// (e.g., writing A from VMEM into LDS once per A load)
constexpr index_t ds_write_inst = Aload_inst;
// Number of MFMA instructions per wave for one block tile:
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
// How often (in MFMA units) we should insert DS (LDS) operations.
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
// How often (in MFMA units) we should insert VMEM buffer loads.
// buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load
// is assumed to cover at most 4 MFMA instructions.
constexpr index_t buffer_load_rep =
min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma
static_for<0, nloop, 1>{}([&](auto j_inst) {
ignore = j_inst;
static_for<0, nloop, 1>{}([&](auto) {
static_for<0, mfma_inst, 1>{}([&](auto i_inst) {
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA
// Insert LDS read/write groups periodically based on ds_rep.
// The % pattern staggers READ and WRITE so they don't collapse
// into the same cycle in the model.
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0)
{
__builtin_amdgcn_sched_group_barrier(
@@ -140,6 +161,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
}
}
// Always mark some VALU work in the loop to reflect auxiliary scalar
// or vector ALU instructions that coexist with MFMA (Blockscale calculation).
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
});
});