moe gemm2 bpreshuffle function passed

This commit is contained in:
mtgu0705
2025-06-18 00:23:44 -05:00
parent 54c930d3b9
commit daa013bf62
3 changed files with 156 additions and 86 deletions

View File

@@ -175,7 +175,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -183,14 +183,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 64,
MPerBlock, 32, KPerBlock,
ScaleBlockSize, 256,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
2, 2,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
@@ -202,14 +202,14 @@ int main(int argc, char* argv[])
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 2;
constexpr ck::index_t valid_tile_num = 2;
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = 13;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 2;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
@@ -351,30 +351,44 @@ int main(int argc, char* argv[])
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 4:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 5:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 6:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 7:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 8:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;

View File

@@ -965,6 +965,54 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
#if 0
printf("blkIdx: %u, blkIdy: %u, tidx: %u, imxdl: %d, inxdl: "
"%d, ikxdl: %d, a_thread_vec=<0x%08x, 0x%08x, 0x%08x, "
"0x%08x>, "
"b_thread_vec=<0x%08x, 0x%08x, 0x%08x, 0x%08x>, "
"a_scale=0x%08x, "
"b_scale=0x%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
im_minor,
in_minor,
ik_minor,
*(reinterpret_cast<const uint32_t*>(
&(a_thread_vec.template AsType<f4x8_t>()[Number<0>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(a_thread_vec.template AsType<f4x8_t>()[Number<1>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(a_thread_vec.template AsType<f4x8_t>()[Number<2>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(a_thread_vec.template AsType<f4x8_t>()[Number<3>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(b_thread_vec.template AsType<f4x8_t>()[Number<0>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(b_thread_vec.template AsType<f4x8_t>()[Number<1>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(b_thread_vec.template AsType<f4x8_t>()[Number<2>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(b_thread_vec.template AsType<f4x8_t>()[Number<3>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(a_scale_thread_vec
.template AsType<AScaleDataType>()[Number<0>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(b_scale_thread_vec
.template AsType<BScaleDataType>()[Number<0>{}]))),
type_convert<float>(
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<0>{}]),
type_convert<float>(
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<1>{}]),
type_convert<float>(
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<2>{}]),
type_convert<float>(
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<3>{}]));
#endif
});
});
if constexpr(m0.value < (MRepeat - LocalPrefetchStages))

View File

@@ -42,12 +42,12 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
{
#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
@@ -79,12 +79,12 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
{
#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -885,7 +885,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
{
// contiguous in LDS
return make_naive_tensor_descriptor(
make_tuple(Number<AK0Number>{}, Number<MPerBlock>{}, AK1Number),
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
}
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
@@ -2025,17 +2025,30 @@ struct GridwiseMoeGemmMX_BPreshuffle
problem.NPadded,
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.M / (MXdlPack * MPerXdl),
// We pad the M unconditionaly for Scale
const auto Padded_Scale_M =
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
64 * KXdlPack * MXdlPack / scale_pack_size_a),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / APackedSize)) *
MPerXdl * MXdlPack / scale_pack_size_a,
64 * KXdlPack * MXdlPack / scale_pack_size_a,
1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
64 * KXdlPack * NXdlPack / scale_pack_size_b),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / BPackedSize)) *
NPerXdl * NXdlPack / scale_pack_size_b,
64 * KXdlPack * NXdlPack / scale_pack_size_b,
1));
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -2102,23 +2115,15 @@ struct GridwiseMoeGemmMX_BPreshuffle
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
// Gride buffer creation
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
#if 0
printf("blkx: %u, blky: %u, tidx: %u, a_grid_size: %ld\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
#endif
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
const auto b_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::SYSTEM_NT1>(
p_b_grid + expert_id * expert_stride,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
// A, B scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -2585,16 +2590,18 @@ struct GridwiseMoeGemmMX_BPreshuffle
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin =
@@ -2615,40 +2622,41 @@ struct GridwiseMoeGemmMX_BPreshuffle
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = 3; // hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferCluster,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors,
CShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
>{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// Sequence support
// arbitray type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferCluster,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors,
CShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
>{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());