From 102151ebcfc456142aead21ced58d612b445634e Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Wed, 14 May 2025 08:13:47 -0500 Subject: [PATCH] temp save --- .../moe_gemm2_xdl_mx_fp4.cpp | 40 ++++++++++++++----- ...pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp | 2 +- .../gpu/grid/gridwise_moe_mx_gemm.hpp | 16 ++++++-- .../threadwise_tensor_slice_transfer.hpp | 11 +++-- include/ck/utility/amd_xdlops.hpp | 12 +++--- .../cpu/reference_moe_mx_gemm2.hpp | 4 -- 6 files changed, 57 insertions(+), 28 deletions(-) diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 230a59df40..5de7ad59fe 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -155,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // clang-format on #else -static constexpr ck::index_t MPerBlock = 16; +static constexpr ck::index_t MPerBlock = 128; static constexpr bool MulRoutedWeight = true; // clang-format off @@ -163,14 +163,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic A0Layout, B0Layout, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - ScaleBlockSize, 64, - MPerBlock, 16, 128, + ScaleBlockSize, 256, + MPerBlock, 128, 128, 32, 32, 16, 16, - 1, 1, - S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 8, 1, 8>, S<2, 1, 1, 1>, + 8, 2, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 16>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>; // clang-format on #endif @@ -183,14 +183,14 @@ int main(int argc, char* argv[]) // per expert: // GEMM shape - constexpr ck::index_t sorted_tile_num = 2; - constexpr ck::index_t valid_tile_num = 2; + constexpr ck::index_t sorted_tile_num = 8; + constexpr ck::index_t valid_tile_num = 8; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t N = 6144; ck::index_t K = 4096; - ck::index_t experts = 2; + ck::index_t experts = 8; ck::index_t tokens = 832; ck::index_t topk = 2; @@ -341,6 +341,24 @@ int main(int argc, char* argv[]) d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; default: a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); @@ -378,7 +396,7 @@ int main(int argc, char* argv[]) auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; -#if 1 +#if 0 printf("a0_t_k_k:\n"); for(int t = 0; t < tokens; ++t) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp index a2b8318512..9e12b0420d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp @@ -318,7 +318,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp( - a_scale_grid_desc_am_ak, make_multi_index(0, 0, thread_offset_k), scale_gather_offsets); + a_scale_grid_desc_am_ak, make_multi_index(0, thread_offset_k, 0), scale_gather_offsets); // B scale load auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 41665e7d45..97b0e13dda 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -529,8 +529,8 @@ struct ThreadwiseTensorSliceTransfer_v2_gather // loop over tensor and copy constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { // MRepeate - static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { // MRepeat + static_for<0, KRepeat, 1>{}([&](auto k0) { // KRepeat constexpr auto current_dst_origin = to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, k0, 0); MoveSrcSliceWindow(src_desc, make_multi_index(0, 0, 0)); @@ -584,9 +584,14 @@ struct ThreadwiseTensorSliceTransfer_v2_gather src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); } + + MoveSrcSliceWindow( + src_desc, + make_multi_index( + 0, 4, 0)); // hacky fix: 4 means xdlops_gemm.KPerXdlops / ScaleBlockSize }); }); - MoveSrcSliceWindow(src_desc, make_multi_index(0, -KRepeat, 0)); + MoveSrcSliceWindow(src_desc, make_multi_index(0, -(KRepeat * 4), 0)); }); // printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n", diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 5f5d6e7ffc..ad383850d6 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -762,7 +762,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> using arg_type = int32x8_t; -#if 0 +#if 1 reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, @@ -788,9 +788,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> "v"(scale_b)); #endif -#if 1 +#if 0 printf("bidx: %u, bidy: %u, tid: %u, A: %08x, %08x, %08x, %08x," - "B:%08x, %08x, %08x, %08x, a_scale: %08x, b_scale: %08x, " + "B:%08x, %08x, %08x, %08x, a_scale: %.f, b_scale: %.f, " "reg_c: %f, %f, %f, %f\n", blockIdx.x, blockIdx.y, @@ -803,8 +803,10 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> bit_cast(arg_b[1]), bit_cast(arg_b[2]), bit_cast(arg_b[3]), - *(reinterpret_cast(&(scale_a))), - *(reinterpret_cast(&(scale_b))), + // *(reinterpret_cast(&(scale_a))), + // *(reinterpret_cast(&(scale_b))), + type_convert(scale_a), + type_convert(scale_b), reg_c.template AsType()[Number<0>{}], reg_c.template AsType()[Number<1>{}], reg_c.template AsType()[Number<2>{}], diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp index 69d20ef5d5..e02f9799cf 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp @@ -89,10 +89,6 @@ struct ReferenceMoeMXGemm2 : public device::BaseOperator auto f_mk_kn_mn = [&](auto m, auto n) { const int K = arg.a_t_k_k_.mDesc.GetLengths()[2]; const ck::index_t SCALE_BLOCK = K / arg.b_e_n_k_scale_.mDesc.GetLengths()[1]; - if(m == 0 && n == 0) - { - printf("SCALE_BLOCK: %d\n", SCALE_BLOCK); - } AccDataType v_acc{0}; ComputeTypeA v_a{0}; ComputeTypeB v_b{0};