From 676966419745fb95a196d3905830ebf0cab28020 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Fri, 31 Oct 2025 05:01:10 -0500 Subject: [PATCH] add printf info --- .../moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp | 40 ++++++++++++++----- ...pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp | 26 ++++++++++-- 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp index 05ac1aab18..2133a7702b 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -177,7 +177,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent constexpr ck::index_t ScaleBlockSize = 32; // scaling block size constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 -static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t MPerBlock = 64; static constexpr bool MulRoutedWeight = true; // clang-format off @@ -185,14 +185,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic A0Layout, B0Layout, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - ScaleBlockSize, 64, - MPerBlock, 128, KPerBlock, + ScaleBlockSize, 128, + MPerBlock, 64, KPerBlock, 16, 16, 16, 16, - 2, 8, - S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, - S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, - 2, 2, S<1, 8, 1, 8>, S<2, 1, 1, 1>, + 4, 2, + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 8, 1, 16>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on @@ -210,10 +210,10 @@ int main(int argc, char* argv[]) ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t N = 6144; - ck::index_t K = 4096; - ck::index_t experts = 256; - ck::index_t tokens = 208; - ck::index_t topk = 8; + ck::index_t K = 256; + ck::index_t experts = 8; + ck::index_t tokens = 4; + ck::index_t topk = 2; if(argc == 1) { @@ -414,6 +414,24 @@ int main(int argc, char* argv[]) DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize()); +#if 1 + printf("a0_t_k_k:\n"); + + for(int t=0;t(&a0_t_k_k(t,tk,k)))); + k += 8; + } + printf("\n"); + } + printf("\n"); + } +#endif + // A scale sorted for(int i = 0; i < sorted_size; i++) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp index d36b5982a2..18ffee5af9 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp @@ -626,7 +626,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * @@ -772,11 +774,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{})); }); }); - __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); - block_sync_lds(); - // constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0; }); + + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * @@ -797,7 +800,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1{}([&](auto m0) { constexpr auto im_major = m0 / MXdlPack; @@ -943,6 +945,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1(), b_scale_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); + + #if 1 + printf("blkIdx: %u, blkIdy: %u, tidx: %u, im_minor: %d, in_minor: " + "%d, ik_minor: %d, a_thread_vec=<0x%08x, 0x%08x, 0x%08x, " + "0x%08x>\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + im_minor, + in_minor, + ik_minor, + *(reinterpret_cast(&(a_thread_vec.template AsType()[Number<0>{}]))), + *(reinterpret_cast(&(a_thread_vec.template AsType()[Number<1>{}]))), + *(reinterpret_cast(&(a_thread_vec.template AsType()[Number<2>{}]))), + *(reinterpret_cast(&(a_thread_vec.template AsType()[Number<3>{}])))); + #endif }); }); });