From e5fb6909456ceeb873cbbf43fb4bf0bd9bf99f18 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Wed, 18 Feb 2026 22:59:37 +0800 Subject: [PATCH] Fix the Composable Kernel CI and versions incompatibility (#4640) ## Motivation This PR has 4 patches: 1. Fix the CI error of grouped gemm. 2. Fix the incompatibility of old linux version. 3. Fix the potential errors of flatmm. 4. Address the previous comments of abquant eight warps pipeline solution. --------- Co-authored-by: illsilin_amdeng --- .../12_reduce/reduce_blockwise_two_call.cpp | 18 ++--- example/ck_tile/18_flatmm/flatmm_basic.cpp | 1 + example/ck_tile/18_flatmm/grouped_flatmm.cpp | 1 + .../18_flatmm/run_grouped_flatmm_example.inc | 2 +- .../run_gemm_quant_example.inc | 9 ++- include/ck_tile/core/config.hpp | 2 +- include/ck_tile/host/check_err.hpp | 4 +- include/ck_tile/host/device_prop.hpp | 6 +- .../ops/epilogue/cshuffle_epilogue.hpp | 10 +-- .../ops/flatmm/kernel/flatmm_kernel.hpp | 66 ++++++++++--------- ...rsal_gemm_as_aquant_bs_bquant_cr_async.hpp | 8 ++- .../grouped_gemm/test_grouped_gemm.cpp | 5 -- 12 files changed, 67 insertions(+), 65 deletions(-) diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp index 0fdc0378b7..88556acaf7 100644 --- a/example/12_reduce/reduce_blockwise_two_call.cpp +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -88,7 +88,7 @@ using DeviceReduceInstance_2 = DeviceReduceMultiBlock in_1(inLengths_1); @@ -174,22 +174,22 @@ int main(int argc, char* argv[]) case 0: break; case 1: in_1.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); - if(beta != 0.0f) + if(beta_ != 0.0f) out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); break; case 2: in_1.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - if(beta != 0.0f) + if(beta_ != 0.0f) out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; default: in_1.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); - if(beta != 0.0f) + if(beta_ != 0.0f) out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); } - if(beta != 0.0f) + if(beta_ != 0.0f) for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++) out.mData[i] = out_ref.mData[i]; }; @@ -200,7 +200,7 @@ int main(int argc, char* argv[]) in_1_dev.ToDevice(in_1.mData.data()); - if(beta != 0.0f) + if(beta_ != 0.0f) out_dev.ToDevice(out.mData.data()); InElementwiseOperation in_elementwise_op; @@ -246,7 +246,7 @@ int main(int argc, char* argv[]) arrOutStrides, reduceDims, static_cast(alpha), - static_cast(beta), + static_cast(beta_), in_1.mData.data(), nullptr, out_ref.mData.data(), @@ -298,7 +298,7 @@ int main(int argc, char* argv[]) arrOutStrides, reduceDims_2, static_cast(alpha), - static_cast(beta), + static_cast(beta_), in_2_dev.GetDeviceBuffer(), nullptr, out_dev.GetDeviceBuffer(), diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index af46884a90..19593a0f04 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "ck_tile/host.hpp" #include "flatmm_basic.hpp" diff --git a/example/ck_tile/18_flatmm/grouped_flatmm.cpp b/example/ck_tile/18_flatmm/grouped_flatmm.cpp index 780a21ba14..6b0135d370 100644 --- a/example/ck_tile/18_flatmm/grouped_flatmm.cpp +++ b/example/ck_tile/18_flatmm/grouped_flatmm.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "flatmm_basic.hpp" diff --git a/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc b/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc index 2027544709..f1889168ae 100644 --- a/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc @@ -166,7 +166,7 @@ int run_contiguous_grouped_flatmm_example_with_layouts( } ck_tile::index_t M = - std::reduce(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) { + std::accumulate(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) { // round up to the multiple of BlockM return acc + (group_m + BlockM - 1) / BlockM * BlockM; }); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index dd7e1abb02..c2954f3bf5 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -35,16 +35,19 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str { static_assert(std::is_same_v); constexpr bool IS_FP8BLOCKSCALE = - QuantMode == ck_tile::QuantType::ABQuantGrouped && BQuantGroupSize::kN == 128 && + QuantMode == ck_tile::QuantType::ABQuantGrouped && (std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v); constexpr bool transpose_c = GemmConfig::TransposeC; constexpr bool eight_warps = - IS_FP8BLOCKSCALE && BQuantGroupSize::kN == 128 && - (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) && +#ifdef CK_GFX950_SUPPORT + IS_FP8BLOCKSCALE && (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) && GemmConfig::K_Warp_Tile == 128; +#else + false; +#endif using ComputeDataType = std::conditional_t; diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index fed9209bad..ed102c86a8 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -73,7 +73,7 @@ #define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4 #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT -#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE +#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_STANDARD #endif #define CK_TILE_FLOAT_TO_FP8_STANDARD 0 diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index a2f6728316..96ec7bec4a 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -137,7 +137,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, int>::value, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); - auto expo = std::log2(std::abs(max_possible_num)); + auto expo = std::floor(std::log2(std::abs(max_possible_num))); double compute_error = 0; if constexpr(is_any_of::value) { @@ -158,7 +158,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, } else { - output_error = std::pow(2, expo - numeric_traits::mant) * 0.5; + output_error = std::pow(2, expo - numeric_traits::mant) * 1.0; } double midway_error = std::max(compute_error, output_error); diff --git a/include/ck_tile/host/device_prop.hpp b/include/ck_tile/host/device_prop.hpp index f28d7df00d..5f021d7bc5 100644 --- a/include/ck_tile/host/device_prop.hpp +++ b/include/ck_tile/host/device_prop.hpp @@ -65,11 +65,7 @@ inline bool is_gfx12_supported() return get_device_name() == "gfx1200" || get_device_name() == "gfx1201"; } -inline bool is_gfx95_supported() -{ - // Check if load transpose is supported. - return get_device_name() == "gfx950"; -} +inline bool is_gfx95_supported() { return get_device_name() == "gfx950"; } inline size_t get_num_cus() { diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index e9a11909c7..b31f8ba02a 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -116,13 +116,13 @@ struct CShuffleEpilogue static constexpr index_t isCTransposed = Problem::isCTransposed; static constexpr bool FixedVectorSize = Problem::FixedVectorSize; static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; -#ifdef __gfx9__ - static constexpr bool AsyncPipeline = (MWave * NWave == 8); +#ifdef __gfx95__ + static constexpr bool EightWave = (MWave * NWave == 8); #else - static constexpr bool AsyncPipeline = false; + static constexpr bool EightWave = false; #endif static constexpr index_t BlockedXDLN_PerWarp = - AsyncPipeline ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp; + EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t MPerIteration = MPerXdl * MWave; @@ -447,7 +447,7 @@ struct CShuffleEpilogue if constexpr(is_950 || is_any_of::value || is_any_of::value) { - if constexpr(AsyncPipeline) + if constexpr(EightWave) { return tile_distribution_encoding< sequence<>, diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 42dab68e91..b2b36adb1e 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -780,29 +780,31 @@ struct FlatmmKernel const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m) { - constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN; - constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK; + constexpr int GM = decltype(kargs.scale_m_ptr)::GranularityMN; + constexpr int GK = decltype(kargs.scale_m_ptr)::GranularityK; - auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale - : 1; // per-token scale + static_assert(GM != -1, + "MakeScaleMWindow should only be instantiated when scale is enabled"); + + // per-tensor (GM==0) -> Mdim = 1, stride 0 + const index_t m_dim = (GM == 0) ? 1 : (kargs.M / GM); + const index_t m_stride = (GM == 0) ? 0 : 1; + + const index_t k_dim = (GK == 0) ? 1 : (splitk_batch_offset.splitted_k / GK); + const index_t k_stride = 0; // your original code keeps K stride 0 - // Step 1: Create tensor view const auto scale_m_view = make_naive_tensor_view( kargs.scale_m_ptr.ptr, - make_tuple(kargs.M / ScaleGranularityM, - ScaleGranularityKA == 0 - ? 1 - : (splitk_batch_offset.splitted_k / ScaleGranularityKA)), - make_tuple(scale_stride_m, 0), - number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, + make_tuple(m_dim, k_dim), + make_tuple(m_stride, k_stride), + number < (GM == 1) ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, number<1>{}); - // Step 2: Create tile window + // Window extents: if GM==0, we still just broadcast from [0,*] return make_tile_window(scale_m_view, make_tuple(number{}, - number < ScaleGranularityKA == 0 - ? TilePartitioner::NPerBlock - : TilePartitioner::KPerBlock > {}), + number < (GK == 0) ? TilePartitioner::NPerBlock + : TilePartitioner::KPerBlock > {}), {block_idx_m, 0}); } @@ -811,27 +813,29 @@ struct FlatmmKernel const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_n) { - constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN; - constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK; + constexpr int GN = decltype(kargs.scale_n_ptr)::GranularityMN; + constexpr int GK = decltype(kargs.scale_n_ptr)::GranularityK; - auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale - : 1; // per-channel scale + static_assert(GN != -1, + "MakeScaleNWindow should only be instantiated when scale is enabled"); + + // per-tensor (GN==0) -> Ndim = 1, stride 0 + const index_t n_dim = (GN == 0) ? 1 : (kargs.N / GN); + const index_t n_stride = (GN == 0) ? 0 : 1; + + const index_t k_dim = (GK == 0) ? 1 : (splitk_batch_offset.splitted_k / GK); + const index_t k_stride = 0; - // Step 1: Create tensor view const auto scale_n_view = make_naive_tensor_view( kargs.scale_n_ptr.ptr, - make_tuple( - ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB), - kargs.N / ScaleGranularityN), - make_tuple(0, scale_stride_n), - number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, + make_tuple(k_dim, n_dim), + make_tuple(k_stride, n_stride), + number < (GN == 1) ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, number<1>{}); - // Step 2: Create tile window return make_tile_window(scale_n_view, - make_tuple(number < ScaleGranularityKB == 0 - ? TilePartitioner::MPerBlock - : TilePartitioner::KPerBlock > {}, + make_tuple(number < (GK == 0) ? TilePartitioner::MPerBlock + : TilePartitioner::KPerBlock > {}, number{}), {0, block_idx_n}); } @@ -854,8 +858,6 @@ struct FlatmmKernel MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m); - const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); @@ -866,6 +868,8 @@ struct FlatmmKernel // Run Epilogue Pipeline with k_batch dispatching if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1) { + const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m); + const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n); if(kargs.k_batch == 1) { auto e_block_window = MakeEBlockWindow( diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp index 04d177d4d6..36c911f060 100755 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp @@ -76,7 +76,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); static constexpr index_t KIterPerWarp = KPerBlock / (KWarp * WarpGemm::kK); - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant; + static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant; static constexpr index_t QScalesPerBlockRow = integer_divide_ceil(KPerBlock / KWarp, BQuantGroupSize::kK); @@ -158,7 +159,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase using BWarpTensor = typename WarpGemm::BWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor; - static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant; + static constexpr bool APreshuffleQuant = Traits::APreshuffleQuant; + static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant; static_assert(std::is_same_v); @@ -364,7 +366,7 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase AQPickerCommon aq_picker( aq_block_tensor); - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { constexpr index_t reg_offset = nIter; auto pull_from_lane = diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp b/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp index 2f37ef1bc0..0f8e25e03b 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp @@ -31,11 +31,6 @@ using KernelTypes = ::testing::Types< std::tuple< Col, Row, Row, F16, F16, F32, F16, True>, std::tuple< Col, Row, Row, F16, F16, F32, F16, False>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>, - std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, True>, std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, False>, std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, True>,