From b1faa0c1c5d96db5226fed3f63bdf6d3af7109b2 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 18 Nov 2025 09:56:40 -0800 Subject: [PATCH] [CK-Tile] Remove usage of tile partitioner's full gemm shape (#3204) gemm shape should be used from the pipeline instead (where it gets from a problem description struct) [ROCm/composable_kernel commit: a3a4eb12bdfc1b2de642f39af458d03aef3a3d60] --- .../ops/flatmm/kernel/flatmm_kernel.hpp | 4 ++-- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 2 +- .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 8 +++---- .../kernel/streamk_gemm_tile_partitioner.hpp | 7 +++---- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 11 +++++----- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 21 +++++++++---------- ...ouped_convolution_backward_data_kernel.hpp | 4 ++-- ...ped_convolution_backward_weight_kernel.hpp | 6 +++--- .../grouped_convolution_forward_kernel.hpp | 4 ++-- 9 files changed, 31 insertions(+), 36 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 7523acc080..d3ecbefd91 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -363,8 +363,8 @@ struct FlatmmKernel template __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) { - constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{}); - constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + constexpr auto N1 = BlockGemmShape::WarpTile::at(number<1>{}); + constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{}); const index_t K_t = kargs.k_batch * K1; const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 411cfe81ed..8a9aa3cdd3 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -369,7 +369,7 @@ struct MoeFlatmmKernel template __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) { - constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{}); const index_t K_t = kargs.k_batch * K1; const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 673f5abc34..fc85c4dcdf 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -386,11 +386,9 @@ template struct StreamKTilePartitioner { - using BlockGemmShape = BlockGemmShapeType; - - static constexpr uint32_t MPerBlock = BlockGemmShape::kM; - static constexpr uint32_t NPerBlock = BlockGemmShape::kN; - static constexpr uint32_t KPerBlock = BlockGemmShape::kK; + static constexpr uint32_t MPerBlock = BlockGemmShapeType::kM; + static constexpr uint32_t NPerBlock = BlockGemmShapeType::kN; + static constexpr uint32_t KPerBlock = BlockGemmShapeType::kK; CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept = delete; diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp index 996ef5a7ef..f32f8b681b 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp @@ -22,11 +22,10 @@ template struct StreamKTilePartitionerBase { - using BlockGemmShape = BlockGemmShapeType; - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t MPerBlock = BlockGemmShapeType::kM; + static constexpr index_t NPerBlock = BlockGemmShapeType::kN; + static constexpr index_t KPerBlock = BlockGemmShapeType::kK; static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType; StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid); diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index e77355ed3d..2aac894a46 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -325,7 +325,7 @@ struct UniversalGemmKernel { __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) { - constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}); const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1); const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1); @@ -584,7 +584,7 @@ struct UniversalGemmKernel const KernelArgs& kargs, const index_t k_size) { - static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); const auto& as_tensor_view = generate_tuple( [&](auto i) { @@ -617,7 +617,7 @@ struct UniversalGemmKernel using BiDataType = remove_cvref_t>; if constexpr(std::is_same_v) { - if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); const index_t K0 = k_size / K1; @@ -649,7 +649,7 @@ struct UniversalGemmKernel } else { - if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); const index_t K0 = k_size / K1; @@ -675,8 +675,7 @@ struct UniversalGemmKernel { index_t kFlatK = GemmPipeline::BlockGemmShape::flatKPerWarp * - (k_size / - TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); + (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; return make_naive_tensor_view( diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 15d2727f3b..6c90d5c1a6 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -276,7 +276,7 @@ struct QuantGemmKernel __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs, const std::size_t k_id = blockIdx.z) { - constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2); + constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1); const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1); @@ -487,7 +487,7 @@ struct QuantGemmKernel const SplitKBatchOffset& splitk_batch_offset) { - static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -537,7 +537,7 @@ struct QuantGemmKernel const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1]; const auto wave_tile_size = - TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ; + GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ; const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size); const auto aq_unmerge_pad0_desc = transform_tensor_descriptor( @@ -597,7 +597,7 @@ struct QuantGemmKernel const auto& b_tensor_view = [&]() { if constexpr(std::is_same_v) { - if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); const index_t K0 = splitk_batch_offset.splitted_k / K1; @@ -627,7 +627,7 @@ struct QuantGemmKernel } else { - if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); const index_t K0 = splitk_batch_offset.splitted_k / K1; @@ -649,10 +649,9 @@ struct QuantGemmKernel { if constexpr(PreshuffleB) { - index_t kFlatK = - GemmPipeline::flatKPerWarp * - (splitk_batch_offset.splitted_k / - TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatK = GemmPipeline::flatKPerWarp * + (splitk_batch_offset.splitted_k / + GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; return make_naive_tensor_view( @@ -837,7 +836,7 @@ struct QuantGemmKernel static_assert(std::is_same_v); using QuantGroupSize = remove_cvref_t; constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0); + constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0); constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; constexpr auto tile_window_width = ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size()); @@ -880,7 +879,7 @@ struct QuantGemmKernel b_pad_view, make_tuple(number{}, number{}), - {static_cast(i_n / TilePartitioner::BlockGemmShape::WarpTile::at(I1)), 0}); + {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); } else { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index 7b8cdb3792..b1ed80b5ea 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -724,8 +724,8 @@ struct GroupedConvolutionBackwardDataKernel const GroupedConvBwdDataKernelArgsSpecialized& kargs, const index_t group_id) { - static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); - static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); const auto& a_tensor_view = [&]() { return make_tensor_view( a_ptr, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 2eb4f2dfd1..3407c67ad1 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -464,7 +464,7 @@ struct GroupedConvolutionBackwardWeightKernel __device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs, const std::size_t k_id = blockIdx.z) { - constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}); const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1); const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1); @@ -646,8 +646,8 @@ struct GroupedConvolutionBackwardWeightKernel WeiDataType* c_ptr, const GroupedConvBwdWeightKernelArgsSpecialized& kargs) { - static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); - static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); const auto& a_tensor_view = [&]() { return make_tensor_view(a_ptr, kargs.a_grid_desc_k_m); // A: out diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 6de331fe6d..4eccd1eebb 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -745,8 +745,8 @@ struct GroupedConvolutionForwardKernel const BDescType& b_desc, const CDescType& c_desc) { - static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); - static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); const auto& a_tensor_view = [&]() { return make_tensor_view(a_ptr, a_desc); }();