diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 7d0cd6542f..429b21c060 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -24,19 +24,20 @@ template using S = ck::Sequence; -using F4 = ck::f4x2_pk_t; -using F16 = ck::half_t; -using BF16 = ck::bhalf_t; -using F32 = float; -using XDataType = ck::e8m0_bexp_t; +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; using A0DataType = F4; -using A1DataType = XDataType; +using A1DataType = XPackedDataType; using B0DataType = F4; -using B1DataType = XDataType; +using B1DataType = XPackedDataType; using EDataType = F16; using AccDataType = F32; using CShuffleDataType = F32; @@ -170,7 +171,9 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 #if 0 static constexpr ck::index_t MPerBlock = 128; @@ -213,14 +216,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ScaleBlockSize, 256, - MPerBlock, 128, 128, - 32, 32, + MPerBlock, 256, KPerBlock, 16, 16, - 8, 2, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, + 16, 16, + 8, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on #endif @@ -328,22 +331,22 @@ int main(int argc, char* argv[]) expert_ids.savetxt("expert_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); - Tensor a1_t_k_k( + Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b1_e_n_k( + Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); // B preshuffle Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); // A, B Scale preshuffle - Tensor a_scale_sorted(HostTensorDescriptor( + Tensor a_scale_sorted(HostTensorDescriptor( {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor a_scale_preshuffled(HostTensorDescriptor( + Tensor a_scale_preshuffled(HostTensorDescriptor( {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b_scale_preshuffled( + Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, {N * Scale_Stride_BN, 1, Scale_Stride_BN})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); @@ -364,50 +367,50 @@ int main(int argc, char* argv[]) case 1: a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_2{0, 1}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_2{0, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_2{0, 1}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_2{0, 1}); d2_e_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); break; case 2: a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 3: a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; case 4: a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; case 5: a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 6: a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; default: a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * @@ -415,35 +418,37 @@ int main(int argc, char* argv[]) DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize() / 2); - DeviceMem a1_device_buf(sizeof(A1DataType) * a_scale_sorted.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); - DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); // A scale sorted for(int i = 0; i < sorted_size; i++) { - int tokenid = sorted_token_ids.mData[i] & 0x00FFFFFF; - int topkid = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) { - if(tokenid = = tokens) + if(token_id == tokens) { a_scale_sorted(i, k) = 0; } else { - a_scale_sorted(i, k) = a1_t_k_k(tokenid, topkid, k); + a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k); } } } - preShuffleBuffer>( - a_scale_sorted.mData.data(), a_scale_preshuffled.mData.data(), sorted_size, K); - preShuffleBuffer>( - b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K); + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); @@ -614,9 +619,9 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm2 @@ -170,7 +176,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base return make_tuple(c_thread_m, c_thread_n); } - using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + using Tuple5 = decltype(CalculateAThreadOriginDataIndex()); /** * @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base. @@ -190,8 +196,8 @@ struct BlockwiseGemmXdlops_mx_pipeline_base * repeat dimensions. */ __host__ __device__ - BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), - Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin = CalculateAThreadOriginDataIndex(), + Tuple5 b_origin = CalculateBThreadOriginDataIndex()) : a_thread_copy_(a_origin), b_thread_copy_(b_origin) { static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), @@ -327,49 +333,63 @@ struct BlockwiseGemmXdlops_mx_pipeline_base __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; } - static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; - static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_m3_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_n3_k; protected: // M1, N1 as double buffer index // Read buffer + Compute buffer - // A[M0, M1, M2, KPack] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( - make_tuple(Number{}, I1, Number{}, Number{}), - make_tuple(Number{}, - Number{}, - Number{}, - I1)); + // A[M0, M1, M2, M3, KPack] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + I1)); - // B[N0, N1, N2, KPack] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( - make_tuple(Number{}, I1, Number{}, Number{}), - make_tuple(Number{}, - Number{}, - Number{}, - I1)); + // B[N0, N1, N2, N3, KPack] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + I1)); // C[M, N, NumRegXdlops] - static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + static constexpr auto c_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + xdlops_gemm.GetRegSizePerXdlops())); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence<1, 1, 1, 1, KThreadChunk>, + Sequence<0, 1, 2, 3, 4>, + 4, A_K1, A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence<1, 1, 1, 1, KThreadChunk>, + Sequence<0, 1, 2, 3, 4>, + 4, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp index 94e034654d..f899c223b9 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp @@ -495,7 +495,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< index_t i = 0; do { - auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf, auto a_buf) { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { // Prefetch a_scales to buf 1 a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, @@ -683,8 +683,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< __builtin_amdgcn_sched_barrier(0); }; // LoopFunc - LoopFunc(I0, I1, I0); - LoopFunc(I1, I0, I1); + LoopFunc(I0, I1); + LoopFunc(I1, I0); i += 2; } while(i < (num_loop - 2)); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp index 30d2d6fcc9..baa038e710 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp @@ -277,16 +277,35 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); + auto a_scale_thread_buf_copy = + make_static_buffer( + a_scale_thread_desc_copy.GetElementSpaceSize()); + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc_copy, + make_tuple(I0, I0), + a_scale_thread_buf_copy); + + a_scale_thread_buf(I0)(Number{}) = + a_scale_thread_buf_copy[Number<0>{}]; + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); + }); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(0, ScalesPerKBlockSize, 0)); + make_multi_index(-MPerBlock, ScalesPerKBlockSize)); // Prefetch b_scales to buf 0 static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -329,15 +348,34 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); + auto a_scale_thread_buf_copy = + make_static_buffer( + a_scale_thread_desc_copy.GetElementSpaceSize()); + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc_copy, + make_tuple(I0, I0), + a_scale_thread_buf_copy); + + a_scale_thread_buf(I1)(Number{}) = + a_scale_thread_buf_copy[Number<0>{}]; + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); + }); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(0, ScalesPerKBlockSize, 0)); + make_multi_index(-MPerBlock, ScalesPerKBlockSize)); // Prefetch b_scales to buf 1 static_for<0, NRepeat, 1>{}([&](auto n0) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp index bc63f57b93..d526373707 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp @@ -138,50 +138,56 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3 - __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + template + __host__ __device__ static constexpr auto + MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_M3_K&) { - constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); - constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); - constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t M0 = TileDesc_M0_M1_M2_M3_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_M3_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_M3_K{}.GetLength(Number<2>{}); + constexpr index_t M3 = TileDesc_M0_M1_M2_M3_K{}.GetLength(Number<3>{}); constexpr index_t K2 = KPack; constexpr index_t K1 = 64 / NPerXDL; constexpr index_t K0 = KRepeat; return transform_tensor_descriptor( - TileDesc_M0_M1_M2_K{}, + TileDesc_M0_M1_M2_M3_K{}, make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4, 5, 6>{})); } - static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = - MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + static constexpr auto a_block_desc_m0_m1_m2_m3_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_m3_k); static constexpr auto ScalesPerKBlockSize = KPerBlock / ScaleBlockSize; // How many mx-vectors per K block @@ -375,7 +381,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}> b_thread_bufs; - constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0); auto a_scale_thread_buf = make_static_buffer( a_scale_thread_desc.GetElementSpaceSize()); @@ -388,7 +394,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); // restore row id and advance to the next set of scales - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(0, ScalesPerKBlockSize, 0)); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); // Prefetch b_scales 1 - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); - b_scale_thread_bufs(I0)(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); }); b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); }); + // restore col id and advance to the next set of scales - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); // Local prefill A1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); // vmem->vgpr-> lds0 @@ -444,28 +453,37 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); - + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); - }); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k0_k1_k2, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); }); }); + // Initialize C + c_thread_buf.Clear(); + // main body if constexpr(HasMainLoop) { @@ -473,51 +491,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0)); + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); - // Prefetch b_scales 2 - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); + // Prefetch b_scales 1 + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); - b_scale_thread_bufs(local_read_buf)(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); }); b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); }); + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); // Local prefill A2 block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(scale_mem_buf)); // Global prefetch A1 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); @@ -526,91 +546,127 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; - vector_type + vector_type b_scale_thread_vec; // Pack scale_thread_buf into scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[mfma_reg_buf] - [Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[mfma_reg_buf] - [Number{}]; + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; }); - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); // KRepeat - }); // NRepeat - }); // MRepeat + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()( + ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()( + ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); // Local prefetch A2 block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); - + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * - xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); - }); + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k0_k1_k2, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); }); }); @@ -618,8 +674,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); - // Prefetch b_scales 2 - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); - b_scale_thread_bufs(I1)(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); }); b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); }); // Local prefill A2 @@ -668,212 +723,299 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack b_scale_thread_buf into b_scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[I0][Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[I0][Number{}]; + a_scale_thread_bufs(I0)[Number{}]; }); - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); // KRepeat - }); // NRepeat - }); // MRepeat + vector_type a_thread_vec; + vector_type b_thread_vec; - // Local prefetch A2 - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); }); }); }); - // A2 * B2 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + // Local prefetch A2 + block_sync_lds(); - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k0_k1_k2, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); }); + }); + }); + // A2 * B2 + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack b_scale_thread_buf into b_scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[I1][Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[I1][Number{}]; + a_scale_thread_bufs(I1)[Number{}]; }); - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); // KRepeat - }); // NRepeat - }); // MRepeat + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); } else if constexpr(TailNum == TailNumber::Odd) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); - // Pack b_scale_thread_buf into b_scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[I0][Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[I0][Number{}]; + a_scale_thread_bufs(I0)[Number{}]; }); - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); // KRepeat - }); // NRepeat - }); // MRepeat + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); } } // TODO: make this field protected when a_scale_thread_copy_ is moved // here static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{})); - - // Is used to copy data from a_scale_grid to a_scale_thread - static constexpr auto a_scale_thread_desc_copy = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + make_tuple(Number{}, + Number{}, + Number{})); // TODO: make this field protected when b_scale_thread_copy_ is moved // here static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{})); - - // Is used to copy data from b_scale_grid to b_scale_thread_buf - static constexpr auto b_scale_thread_desc_copy = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + make_tuple(Number{}, + Number{}, + Number{})); protected: static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( @@ -884,7 +1026,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3, pk_i4_t> || + is_same_v, f4x2_pk_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t> || + is_same_v, f4x2_pk_t>) + return 2; + else + return 1; + }(); + static constexpr bool is_single_rate_mfma = false; static constexpr auto is_scale_mfma = true; using mfma_selector = MfmaSelector; - static constexpr index_t APackedSize = []() { - if constexpr(is_same_v, pk_i4_t> || - is_same_v, f4x2_pk_t>) - return 2; - else - return 1; - }(); - - static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t> || - is_same_v, f4x2_pk_t>) - return 2; - else - return 1; - }(); - __host__ static auto CalculateGridSize(index_t M, index_t N) { const index_t nblock = math::integer_divide_ceil(N, NPerBlock); @@ -317,7 +317,11 @@ struct GridwiseMoeGemmMX return math::integer_divide_ceil(N, NPerBlock); } - template + template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) { constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); @@ -326,10 +330,12 @@ struct GridwiseMoeGemmMX return transform_tensor_descriptor( TileDesc_K0_MN_K1{}, make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}))), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); } __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( @@ -513,16 +519,18 @@ struct GridwiseMoeGemmMX template __host__ __device__ static constexpr auto - MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&) { - return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + return MakeGemmMmaTileDescriptor( + ABlockDesc_AK0_M_AK1{}); } template __host__ __device__ static constexpr auto - MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&) { - return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + return MakeGemmMmaTileDescriptor( + BBlockDesc_BK0_N_BK1{}); } template @@ -789,7 +797,7 @@ struct GridwiseMoeGemmMX else if constexpr(is_same_v) { a_scale_k_split_offset = - k_id * karg.KRead / (ScaleBlockSize / PackedSize) * karg.StrideScaleA; + k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA; } // Calculate B scale offset @@ -939,8 +947,11 @@ struct GridwiseMoeGemmMX __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack - return make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, Number{}, Number{})); + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); } __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() @@ -969,9 +980,9 @@ struct GridwiseMoeGemmMX AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), - decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K( GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), - decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K( GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, @@ -1395,21 +1406,26 @@ struct GridwiseMoeGemmMX auto b_block_buf = make_static_buffer( b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< - BDataType, - BDataType, - decltype(b_grid_desc_bpreshuffled), - decltype(b_block_desc_bk0_n_bk1), - Sequence{}, I1, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, - 3, - BBlockTransferSrcScalarPerVector, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_grid_desc_bpreshuffled, - make_multi_index(n_block_data_idx_on_grid, - get_warp_local_1d_id() % NWave, - 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<1, 2, 0, 3>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1511,15 +1527,17 @@ struct GridwiseMoeGemmMX BScaleDataType, BScaleDataType, decltype(b_scale_grid_desc_bn_ak), - decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy), - Sequence<1, 1>, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector true>( b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k)); + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, @@ -1958,17 +1976,17 @@ struct GridwiseMoeGemmMX problem.NPadded, problem.StrideC); - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( - make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, - math::integer_divide_ceil(problem.K, ScaleBlockSize) / - ScalesPerXdlopsRunPerThread, - ScalesPerXdlopsRunPerThread), - make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), - ScalesPerXdlopsRunPerThread, - 1)); - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)), - make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1)); + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -2145,62 +2163,48 @@ struct GridwiseMoeGemmMX const auto waveId_m = wave_idx[I0]; const auto waveId_n = wave_idx[I1]; - static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; - auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / - mfma.selected_mfma.num_threads_per_blk; - - auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl; + auto a_thread_offset_m = waveId_m; // get each thread's offset int the scale tensor const index_t token_scale_pos = block_m_id * MPerBlock; if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray scale_gather_offsets; - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { - const index_t fused_token = - p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWave + a_thread_offset_m]; - index_t token_offset = fused_token & 0xffffff; - if constexpr(!IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } - scale_gather_offsets(m0) = - token_offset * math::integer_divide_ceil(problem.K, ScaleBlockSize); - }); - - auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2_gather< + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< AScaleDataType, AScaleDataType, decltype(a_scale_grid_desc_am_ak), decltype(BlockwiseGemmPipe::a_scale_thread_desc), - Sequence<1, 1, 1>, // SliceLengths - Sequence<0, 1, 2>, // DimAccessOrder - 2, // SrcVectorDim - 1, // SrcScalarPerVector - 1, // SrcScalarStrideInVector - true, - MXdlPerWave, - KRepeat>( - a_scale_grid_desc_am_ak, make_multi_index(0, thread_offset_k, 0), scale_gather_offsets); + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); // B scale load - auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl; + auto b_thread_offset_n = waveId_n; - auto b_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, - true>( - b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k)); + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); if constexpr(IsInputGemm) { @@ -2231,15 +2235,18 @@ struct GridwiseMoeGemmMX BScaleDataType, BScaleDataType, decltype(b_scale_grid_desc_bn_ak), - decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy), - Sequence<1, 1>, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector true>( b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k)); + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,