From d1d56e89efead039bdb39b4f8d812b3fad0bb6b0 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Mon, 26 May 2025 09:29:36 +0000 Subject: [PATCH] fix the correctness issue --- .../gemm_mx_bpreshuffle_common.hpp | 56 ++++--- .../gemm_mx_fp4_bpreshuffle.cpp | 14 +- ...blockwise_gemm_mx_pipeline_xdlops_base.hpp | 20 +-- ...gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp | 145 ++++++++++-------- ...se_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 11 +- 5 files changed, 133 insertions(+), 113 deletions(-) diff --git a/example/67_gemm_microscaling/gemm_mx_bpreshuffle_common.hpp b/example/67_gemm_microscaling/gemm_mx_bpreshuffle_common.hpp index 427b68f3c6..8071014e25 100644 --- a/example/67_gemm_microscaling/gemm_mx_bpreshuffle_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_bpreshuffle_common.hpp @@ -253,6 +253,22 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c using AScaleLayout = Row; using BScaleLayout = Col; + const auto APackedSize = []() { + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) + return 2; + else + return 1; + }(); + + const auto BPackedSize = []() { + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) + return 2; + else + return 1; + }(); + auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); @@ -362,26 +378,12 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c int NPerXdl = 16; // Fixed 16 preShuffleBuffer(b_k_n.mData.data(), b_preshuffled.mData.data(), N, K, NPerXdl); #endif - printf("a:\n"); - for(ck::index_t i = 0; i < M; i++) - { - for(ck::index_t j = 0; j < K; j += 2) - { - printf("%02x ", *reinterpret_cast(&a_m_k(i, j))); - if(j % 32 == 31) - { - printf("\n"); - } - } - printf("\n"); - } - - // printf("b:\n"); - // for(ck::index_t i = 0; i < N; i++) + // printf("a:\n"); + // for(ck::index_t i = 0; i < M; i++) // { // for(ck::index_t j = 0; j < K; j += 2) // { - // printf("%02x ", *reinterpret_cast(&b_k_n(j, i))); + // printf("%02x ", *reinterpret_cast(&a_m_k(i, j))); // if(j % 32 == 31) // { // printf("\n"); @@ -389,6 +391,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c // } // printf("\n"); // } + + // printf("b:\n"); + // for(ck::index_t i = 0; i < N; i++) + // { + // for(ck::index_t j = 0; j < K; j += 2) + // { + // printf("%02x ", *reinterpret_cast(&b_preshuffled(j, i))); + // if(j % 128 == 126) + // { + // printf("\n"); + // } + // } + // // printf("\n"); + // } // printf("b_scale:\n"); // for(ck::index_t i = 0; i < N; i++) // { @@ -547,9 +563,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c // partial sums(K/ScaleBlockSize)] // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + std::size_t num_btype = sizeof(ADataType) * M * K / APackedSize+ + sizeof(BDataType) * K * N / BPackedSize+ sizeof(CDataType) * M * N + - sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; + sizeof(XDataType) * M * K / ScaleBlockSize + + sizeof(XDataType) * N * K / ScaleBlockSize; float tflops = static_cast(flop) / 1.E9 / ave_time; diff --git a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp index 13b4bc69b7..960e5e23d2 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp @@ -49,24 +49,24 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle CElementOp, // CElementwiseOperation GemmSpec, // GemmSpec ScaleBlockSize, // ScaleBlockSize: Scaling block size - 64, // BlockSize: Thread block size - 64, // MPerBlock - 64, // NPerBlock + 256, // BlockSize: Thread block size + 128, // MPerBlock + 256, // NPerBlock KPerBlock, // KPerBlock 16, // AK1 16, // BK1 16, // MPerXDL 16, // NPerXDL - 4, // MXdlPerWave + 8, // MXdlPerWave 4, // NXdlPerWave - S<8, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 16, // ABlockTransferSrcScalarPerVector 16, // ABlockTransferDstScalarPerVector_AK1 true, // ABlockLdsExtraM - S<8, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim @@ -75,7 +75,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle true, // BBlockLdsExtraN 2, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle - S<1, 16, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock BlkGemmPSched, // BlkGemmPipeSched BlkGemmPVer, // BlkGemmPipelineVer diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp index 9b698c6564..430014d483 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -382,29 +382,19 @@ struct BlockwiseGemmXdlops_mx_pipeline_base // Read buffer + Compute buffer // A[M0, M1, M2, KPack] static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor(make_tuple(Number{}, + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, Number{}, Number{}, - Number{}), - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - I1)); + Number{})); // B[N0, N1, N2, KPack] static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor(make_tuple(Number{}, + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, - Number{}, Number{}, - Number{}), - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - I1)); + Number{}, + Number{})); // C[M, N, NumRegXdlops] static constexpr auto c_thread_desc_ = diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp index f6b232bc39..56a83b9f66 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp @@ -446,30 +446,82 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle(&(a_thread_buf(Number<0>{}))), - *reinterpret_cast(&(a_thread_buf(Number<1>{}))), - *reinterpret_cast(&(a_thread_buf(Number<2>{}))), - *reinterpret_cast(&(a_thread_buf(Number<3>{}))), - *reinterpret_cast(&(a_thread_buf(Number<4>{}))), - *reinterpret_cast(&(a_thread_buf(Number<5>{}))), - *reinterpret_cast(&(a_thread_buf(Number<6>{}))), - *reinterpret_cast(&(a_thread_buf(Number<7>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<0>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<1>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<2>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<3>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<4>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<5>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<6>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<7>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<8+0>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<8+1>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<8+2>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<8+3>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<8+4>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<8+5>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<8+6>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<8+7>{}))), get_thread_local_1d_id(), - *reinterpret_cast(&(a_thread_buf(Number<8+0>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+1>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+2>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+3>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+4>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+5>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+6>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+7>{}))) + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+0>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+1>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+2>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+3>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+4>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+5>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+6>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+7>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+8+0>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+8+1>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+8+2>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+8+3>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+8+4>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+8+5>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+8+6>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<16+8+7>{}))), + get_thread_local_1d_id(), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+0>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+1>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+2>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+3>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+4>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+5>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+6>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+7>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+8+0>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+8+1>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+8+2>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+8+3>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+8+4>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+8+5>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+8+6>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+8+7>{}))), + get_thread_local_1d_id(), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+0>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+1>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+2>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+3>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+4>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+5>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+6>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+7>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+8+0>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+8+1>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+8+2>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+8+3>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+8+4>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+8+5>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+8+6>{}))), + *reinterpret_cast(&(b_thread_bufs(I0)(Number<32+16+8+7>{}))) ); - +#endif // Initialize C c_thread_buf.Clear(); __builtin_amdgcn_sched_barrier(0); @@ -996,11 +1048,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle(&(a_thread_vec.template AsType()(Number<0>{}))), *reinterpret_cast(&(a_thread_vec.template AsType()(Number<1>{}))), @@ -1010,7 +1062,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle(&(a_thread_vec.template AsType()(Number<5>{}))), *reinterpret_cast(&(a_thread_vec.template AsType()(Number<6>{}))), *reinterpret_cast(&(a_thread_vec.template AsType()(Number<7>{}))), - get_thread_local_1d_id(), *reinterpret_cast(&(a_thread_vec.template AsType()(Number<8+0>{}))), *reinterpret_cast(&(a_thread_vec.template AsType()(Number<8+1>{}))), *reinterpret_cast(&(a_thread_vec.template AsType()(Number<8+2>{}))), @@ -1018,13 +1069,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle(&(a_thread_vec.template AsType()(Number<8+4>{}))), *reinterpret_cast(&(a_thread_vec.template AsType()(Number<8+5>{}))), *reinterpret_cast(&(a_thread_vec.template AsType()(Number<8+6>{}))), - *reinterpret_cast(&(a_thread_vec.template AsType()(Number<8+7>{}))) - ); - printf("Tid: %02d, ik, im, in = %d, %d, %d\n" - "Tid: %02d, B %02x %02x %02x %02x %02x %02x %02x %02x\n" - "Tid: %02d, B %02x %02x %02x %02x %02x %02x %02x %02x\n", - get_thread_local_1d_id(), - ikxdl.value, imxdl.value, inxdl.value, + *reinterpret_cast(&(a_thread_vec.template AsType()(Number<8+7>{}))), get_thread_local_1d_id(), *reinterpret_cast(&(b_thread_vec.template AsType()(Number<0>{}))), *reinterpret_cast(&(b_thread_vec.template AsType()(Number<1>{}))), @@ -1034,7 +1079,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle(&(b_thread_vec.template AsType()(Number<5>{}))), *reinterpret_cast(&(b_thread_vec.template AsType()(Number<6>{}))), *reinterpret_cast(&(b_thread_vec.template AsType()(Number<7>{}))), - get_thread_local_1d_id(), *reinterpret_cast(&(b_thread_vec.template AsType()(Number<8+0>{}))), *reinterpret_cast(&(b_thread_vec.template AsType()(Number<8+1>{}))), *reinterpret_cast(&(b_thread_vec.template AsType()(Number<8+2>{}))), @@ -1080,8 +1124,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}([&](auto k) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { @@ -1112,28 +1156,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle(&(a_thread_buf(Number<0>{}))), - *reinterpret_cast(&(a_thread_buf(Number<1>{}))), - *reinterpret_cast(&(a_thread_buf(Number<2>{}))), - *reinterpret_cast(&(a_thread_buf(Number<3>{}))), - *reinterpret_cast(&(a_thread_buf(Number<4>{}))), - *reinterpret_cast(&(a_thread_buf(Number<5>{}))), - *reinterpret_cast(&(a_thread_buf(Number<6>{}))), - *reinterpret_cast(&(a_thread_buf(Number<7>{}))), - get_thread_local_1d_id(), - *reinterpret_cast(&(a_thread_buf(Number<8+0>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+1>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+2>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+3>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+4>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+5>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+6>{}))), - *reinterpret_cast(&(a_thread_buf(Number<8+7>{}))) - ); } }); } @@ -1143,16 +1165,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}, + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, Number{}, Number{}, - Number{}), - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - I1)); + Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4{}; - return make_naive_tensor_descriptor( - make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber), - make_tuple(NWave * NXdlPack * K0 * NkSwizzleNumber, - NXdlPack * K0 * NkSwizzleNumber, - K0 * NkSwizzleNumber, - NkSwizzleNumber, - I1)); + return make_naive_tensor_descriptor_packed( + make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); } __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( @@ -1903,7 +1898,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle Number{}, Number{}, Number{}>, - Sequence<1, 2, 3, 0, 4>, + Sequence<0, 1, 2, 3, 4>, 4, BBlockTransferSrcScalarPerVector, BThreadTransferSrcResetCoordinateAfterRun,