mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add the gfx1011 support on CK Tile with the SGPR builtin reading protection (#3350)
* Finish the fixes
* Add the gfx1010 support macro
* Fix the compilation error
This commit is contained in:
@@ -357,6 +357,12 @@ struct amdgcn_compiler_target_state
|
||||
#endif // __gfx950__
|
||||
|
||||
// GFX10
|
||||
#if defined(__gfx1010__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1010 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1030__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
|
||||
#else
|
||||
@@ -493,6 +499,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \
|
||||
|
||||
@@ -533,7 +533,8 @@ struct tile_scatter_gather
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
m0_set_with_memory(
|
||||
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
|
||||
|
||||
using Traits = load_store_traits;
|
||||
|
||||
|
||||
@@ -517,7 +517,8 @@ struct tile_window_linear
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
m0_set_with_memory(
|
||||
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
|
||||
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
|
||||
|
||||
@@ -99,28 +99,49 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
template <index_t nloop>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// Estimated number of VMEM vector loads for A per block:
|
||||
// total A bytes / (threads per block * vector width)
|
||||
constexpr index_t Aload_inst =
|
||||
(kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize;
|
||||
// Estimated number of VMEM vector loads for B per block:
|
||||
// total B bytes / (threads per block * vector width)
|
||||
constexpr index_t Bload_inst =
|
||||
(kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize;
|
||||
|
||||
// Estimated number of VMEM loads for B's quant data (e.g. scales / zp).
|
||||
// First ceil-divide by quant group size (how many elements share one scale),
|
||||
// then by vector width to get an approximate number of vector loads.
|
||||
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
|
||||
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
|
||||
QuantGroupSize::kK * QuantGroupSize::kK),
|
||||
VectorLoadSize);
|
||||
constexpr index_t kLdsVec = 8;
|
||||
|
||||
// TODO: Hardcoded; needs to change in the future. Number of instructions emitted per iteration.
|
||||
constexpr index_t kLdsInstCycle = 8;
|
||||
// Total VMEM load instructions (A + B + quant data)
|
||||
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
|
||||
constexpr index_t ds_read_inst = kMPerBlock / kLdsVec;
|
||||
constexpr index_t ds_write_inst = Aload_inst;
|
||||
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
|
||||
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
|
||||
// Approximate number of LDS reads per block
|
||||
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
|
||||
// Approximate number of LDS writes per block
|
||||
// (e.g., writing A from VMEM into LDS once per A load)
|
||||
constexpr index_t ds_write_inst = Aload_inst;
|
||||
// Number of MFMA instructions per wave for one block tile:
|
||||
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
|
||||
// How often (in MFMA units) we should insert DS (LDS) operations.
|
||||
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
|
||||
// How often (in MFMA units) we should insert VMEM buffer loads.
|
||||
// buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load
|
||||
// is assumed to cover at most 4 MFMA instructions.
|
||||
constexpr index_t buffer_load_rep =
|
||||
min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma
|
||||
|
||||
static_for<0, nloop, 1>{}([&](auto j_inst) {
|
||||
ignore = j_inst;
|
||||
static_for<0, nloop, 1>{}([&](auto) {
|
||||
static_for<0, mfma_inst, 1>{}([&](auto i_inst) {
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA
|
||||
|
||||
// Insert LDS read/write groups periodically based on ds_rep.
|
||||
// The % pattern staggers READ and WRITE so they don't collapse
|
||||
// into the same cycle in the model.
|
||||
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
@@ -140,6 +161,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
|
||||
}
|
||||
}
|
||||
// Always mark some VALU work in the loop to reflect auxiliary scalar
|
||||
// or vector ALU instructions that coexist with MFMA (Blockscale calculation).
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user