From 9b65e9ec43498c79c873dd7e40d3a1b02435a69a Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Sat, 13 Sep 2025 21:54:08 -0600 Subject: [PATCH] [CK TILE GEMM] set correct value to TiledMMAPermuteN_ (#2839) - TiledMMAPermuteN_ should be set to true when config is GemmConfigPreshufflePrefill [ROCm/composable_kernel commit: e5d73da2da96e7de050957dc6453c8347b492baa] --- .../gemm_weight_preshuffle_invoker.hpp | 5 ++++- .../ops/epilogue/cshuffle_epilogue.hpp | 22 +++++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp index 9de1a018db..b47dd8d8a7 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp @@ -106,7 +106,10 @@ struct WeightPreshuffleInvoker GemmConfig::K_Warp_Tile, UniversalGemmProblem::TransposeC, memory_operation, - GemmConfig::NumWaveGroups>>; + GemmConfig::NumWaveGroups, + false, + 1, + true>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index ed73f7e9f4..628af0e0b3 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -12,6 +12,22 @@ namespace ck_tile { +template +concept HasDataType = requires { typename T::DataType; }; + +template +struct GetDataType +{ + using type = float; +}; + +template + requires HasDataType +struct GetDataType +{ + using type = typename T::DataType; // Use T::ScaleN::DataType +}; + template ::value && !std::is_same::value; // Tiles to hold row/col scales when present - using SMType = - std::conditional_t, float>; - using SNType = - std::conditional_t, float>; + using SMType = typename GetDataType>::type; + using SNType = typename 
GetDataType>::type; auto sm_tile = make_static_distributed_tensor(dram_tile_distribution); auto sn_tile = make_static_distributed_tensor(dram_tile_distribution);