From 8d51a4ae96a03edebe6fbf63e9ddbee46062f1c7 Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Wed, 7 May 2025 09:40:47 +0000 Subject: [PATCH] wip2 --- .../67_gemm_microscaling/gemm_mx_common.hpp | 26 +++++++++---------- example/67_gemm_microscaling/gemm_mx_fp8.cpp | 22 ++++++++-------- .../blockwise_gemm_pipeline_xdlops_v1_mx.hpp | 23 +++++++++------- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 8 +++--- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 7 ++++- include/ck/utility/type_convert.hpp | 5 ++++ .../cpu/reference_mx_gemm.hpp | 23 +++++++++++++--- 7 files changed, 72 insertions(+), 42 deletions(-) diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 7e28919a2a..87392002c6 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -200,10 +200,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c switch(config.init_method) { case 0: // Initializations for development and debugging - ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); - ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); - ck::utils::FillConstant{ck::type_convert(1.f)}(b_k_n); - ck::utils::FillConstant{ck::type_convert(2.0f)}(b_k_n_scale); + ck::utils::FillConstant{ck::type_convert(ck::float2_t(1.0f))}(a_m_k); + ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); + ck::utils::FillConstant{ck::type_convert(ck::float2_t(1.0f))}(b_k_n); + ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); if(config.verbosity > 0) { std::cout << "Init A = {1}" << std::endl; @@ -347,16 +347,16 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c std::cout << "Comparing results..." << std::endl; } - if(config.init_method == 0) - { - auto expected = static_cast(K); - auto computed = type_convert(c_m_n_device_result(1, 12)); + // if(config.init_method == 0) + // { + // auto expected = static_cast(K); + // auto computed = type_convert(c_m_n_device_result(1, 12)); - res_verified = res_verified && std::abs(expected - computed) <= 0.0f; - std::cout << "\nExpected vs Computed: " << expected << " vs " << computed - << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl - << std::endl; - } + // res_verified = res_verified && std::abs(expected - computed) <= 0.0f; + // std::cout << "\nExpected vs Computed: " << expected << " vs " << computed + // << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl + // << std::endl; + // } res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, c_m_n_host_result, diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8.cpp index e40f9d0391..9d25fa015c 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp8.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp8.cpp @@ -45,24 +45,24 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle CElementOp, // CElementwiseOperation GemmSpec, // GemmSpec ScaleBlockSize, // ScaleBlockSize: Scaling block size - 256, // BlockSize: Thread block size - 128, // MPerBlock - 128, // NPerBlock + 64, // BlockSize: Thread block size + 16, // MPerBlock + 16, // NPerBlock KPerBlock, // KPerBlock 16, // AK1 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 2, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + 16, // MPerXDL + 16, // NPerXDL + 1, // MXdlPerWave + 1, // NXdlPerWave + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 16, // ABlockTransferSrcScalarPerVector 16, // ABlockTransferDstScalarPerVector_AK1 false, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim @@ -71,8 +71,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle false, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + S<1, 16, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock BlkGemmPSched, // BlkGemmPipeSched BlkGemmPVer, // BlkGemmPipelineVer ADataType, // ComputeTypeA diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp index 14c2a11347..aed7dcd08b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp @@ -344,13 +344,13 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}), - // a_block_buf, - // a_thread_desc_, - // make_tuple(m0, I0, k, Number{}), - // a_thread_buf); + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); }); }); static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -406,13 +406,13 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx()(s) = b_scale_thread_buf[Number{}]; }); - CK_TILE_PRINT(); - CK_TILE_PRINT(); + // CK_TILE_PRINT(); + // CK_TILE_PRINT(); using mfma_input_type_a = typename vector_type::type; // mfma input type = pk_f4_t, 32 - CK_TILE_PRINT(); + // CK_TILE_PRINT(); using mfma_input_type_b = typename vector_type::type; @@ -538,6 +538,9 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx(&a_scale_thread_buf[I0]), + *reinterpret_cast(&b_scale_thread_buf[I0])); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index b6b6f92135..c9280373a2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -167,15 +167,16 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // // Should be a multiple of k_per_blk. // TODO: Move this to blockwise pipeline base - static constexpr index_t KPack = // = num of pk_f4 - math::max(lcm_AK1_BK1, // num of pk_f4 + // KPack in packed data types for pk A/B + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk / - 2); // num of f4 + 2); using ThisThreadBlock = ThisThreadBlock; @@ -1567,6 +1568,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // shuffle C and write out { + // printf("c_thread_buf %f %f\n", c_thread_buf[I0], c_thread_buf[I1]); static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 15b2661c8d..e4e50a9325 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1140,6 +1140,11 @@ struct MfmaSelector { return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } template <> constexpr auto GetMfma() @@ -1443,7 +1448,7 @@ struct XdlopsGemm const ScaleB& b_scale_thread, FloatC& p_c_thread) const { - static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + static_for<0, KPack * 2 / mfma_instr.k_per_blk, 1>{}([&](auto k) { if constexpr(!TransposeC) { mfma_instr.template run( diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 04ae046ac8..3fe6dad194 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1572,6 +1572,11 @@ inline __host__ __device__ f4x2_t type_convert(float2_t x) return f4_convert_rne(x); #endif } +template <> +inline __host__ __device__ f4x2_pk_t type_convert(float2_t x) +{ + return static_cast(type_convert(x)); +} // convert vector of 32 fp32 to vector of 32 fp4 template <> diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index e8fdcf1acd..d17c532208 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -84,6 +84,7 @@ struct ReferenceMXGemm : public device::BaseOperator const auto N = arg.b_k_n_.mDesc.GetLengths()[1]; const auto K = arg.a_m_k_.mDesc.GetLengths()[1]; const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1]; + printf("K: %d\n", K); for(size_t m = 0; m < M; m++) { @@ -95,15 +96,29 @@ struct ReferenceMXGemm : public device::BaseOperator if(k % 2 == 1) a_m_k_scaled(m, k) = type_convert( - f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) * + f4_t(arg.a_m_k_(m, k / 2).template unpack<>(Number<1>{}))) * type_convert( arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); else a_m_k_scaled(m, k) = type_convert( - f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) * + f4_t(arg.a_m_k_(m, k / 2).template unpack<>(Number<0>{}))) * type_convert( arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + if(m == 0) + { + printf("a_m_k_scaled(%zu, %zu): %f = %f * %f\n", + m, + k, + a_m_k_scaled(m, k), + k % 2 == 1 + ? type_convert(f4_t( + arg.a_m_k_(m, k / 2).template unpack<>(Number<1>{}))) + : type_convert(f4_t( + arg.a_m_k_(m, k / 2).template unpack<>(Number<0>{}))), + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK))); + } } else { @@ -124,13 +139,13 @@ struct ReferenceMXGemm : public device::BaseOperator if(k % 2 == 1) b_k_n_scaled(k, n) = type_convert( - f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) * + f4_t(arg.b_k_n_(k / 2, n).template unpack<>(Number<1>{}))) * type_convert( arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); else b_k_n_scaled(k, n) = type_convert( - f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) * + f4_t(arg.b_k_n_(k / 2, n).template unpack<>(Number<0>{}))) * type_convert( arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); }