16x16 function merged to moe

This commit is contained in:
OscarXu
2025-05-08 14:33:31 +08:00
parent 8021128572
commit 62877fa074
4 changed files with 458 additions and 691 deletions

View File

@@ -197,96 +197,28 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
constexpr auto num_pk_fma_per_kscaleblock = MPerXDL == 16 ? 2 : 8;
constexpr auto num_mfma_per_kscaleblock =
MPerXDL == 16 ? KScaleBlock / 32 : KScaleBlock / 16;
#if 0
// B global
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
/* Judging issue v_pk_fma */
if constexpr((i + 1) % num_mfma_per_kscaleblock == 0)
{
__builtin_amdgcn_sched_group_barrier(
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
}
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
// A global
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((num_buffer_load_inst_b + 2 * i + 1) % num_mfma_per_kscaleblock == 0)
{
__builtin_amdgcn_sched_group_barrier(
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
}
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((num_buffer_load_inst_b + 2 * i + 2) % num_mfma_per_kscaleblock == 0)
{
__builtin_amdgcn_sched_group_barrier(
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
}
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
// A local
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((num_buffer_load_inst_b + 2 * num_buffer_load_inst_a + i + 1) %
num_mfma_per_kscaleblock ==
0)
{
__builtin_amdgcn_sched_group_barrier(
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
}
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
});
#elif 1 // v_mul occured too early causing vmcnt stall
// B global
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
/* Judging issue v_pk_fma */
if constexpr((i + 1) % num_mfma_per_kscaleblock == 0)
{
__builtin_amdgcn_sched_group_barrier(
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
}
});
// A global
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
if constexpr((num_buffer_load_inst_b + 2 * i + 1) % num_mfma_per_kscaleblock == 0)
{
__builtin_amdgcn_sched_group_barrier(
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
if constexpr((num_buffer_load_inst_b + 2 * i + 2) % num_mfma_per_kscaleblock == 0)
{
__builtin_amdgcn_sched_group_barrier(
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
}
});
// A local
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
if constexpr((num_buffer_load_inst_b + 2 * num_buffer_load_inst_a + i + 1) %
num_mfma_per_kscaleblock ==
0)
{
__builtin_amdgcn_sched_group_barrier(
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
}
});
#endif
}
template <bool HasMainLoop,
@@ -426,12 +358,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
});
});
// printf("blockIdx.y = %d, blockIdx.x = %d, threadIdx.x = %d, c_scale_thread_buf = <%f>\n",
// blockIdx.y,
// blockIdx.x,
// threadIdx.x,
// c_scale_thread_buf[Number<0>{}]);
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
@@ -583,6 +509,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
});
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
@@ -600,19 +542,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
});
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
@@ -638,8 +567,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};
LoopFunc(I0, I1);

View File

@@ -401,16 +401,16 @@ struct DeviceMoeGemmBlockScale
};
#endif
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
4 * (1 + GridwiseGemm::NWave);
constexpr auto estimated_reg_b =
NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2);
constexpr auto estimated_reg_c =
MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4;
constexpr auto estimated_reg_total =
estimated_reg_a + estimated_reg_b + estimated_reg_c;
// constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
// 4 * (1 + GridwiseGemm::NWave);
// constexpr auto estimated_reg_b =
// NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2);
// constexpr auto estimated_reg_c =
// MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4;
// constexpr auto estimated_reg_total =
// estimated_reg_a + estimated_reg_b + estimated_reg_c;
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
constexpr index_t minimum_occupancy = 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
@@ -704,7 +704,7 @@ struct DeviceMoeGemmBlockScale
index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
index_t KBatch,
// index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
@@ -727,7 +727,7 @@ struct DeviceMoeGemmBlockScale
StrideC,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
1, //KBatch,
a_element_op,
b_element_op,
c_element_op);
@@ -749,7 +749,9 @@ struct DeviceMoeGemmBlockScale
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}};
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"}};
// clang-format off
str << "DeviceMoeGEmm"

View File

@@ -187,9 +187,10 @@ struct GridwiseMoeGemmBlockScale
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
// static constexpr index_t NumTokens = 1;
@@ -249,7 +250,7 @@ struct GridwiseMoeGemmBlockScale
}
__host__ __device__ static auto CalculateBK0Shuffled(index_t K)
{
return math::integer_divide_ceil(K, KLane * KPack);
return math::integer_divide_ceil(K, KLane * KPack / KGroup);
}
__host__ __device__ static auto CalculateKPadded(index_t K)
@@ -391,7 +392,7 @@ struct GridwiseMoeGemmBlockScale
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1334,7 +1335,7 @@ struct GridwiseMoeGemmBlockScale
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1946,7 +1947,7 @@ struct GridwiseMoeGemmBlockScale
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -2076,6 +2077,7 @@ struct GridwiseMoeGemmBlockScale
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
@@ -2371,6 +2373,45 @@ struct GridwiseMoeGemmBlockScale
I0,
cde_lds_and_global_step);
}
// // print C
// printf("tid: %d, blkid: %d, "
// "c_thread_buf = <%1.f, %1.f, %1.f>\n "
// // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f,"
// // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f\n"
// , get_thread_local_1d_id(), block_m_id,
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<0>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<1>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<2>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<3>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<4>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<5>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<6>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<7>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<8>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<9>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<10>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<11>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<12>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<13>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<14>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<3>{}]);
});
}
}