diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp index 9de1a018db..b47dd8d8a7 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp @@ -106,7 +106,10 @@ struct WeightPreshuffleInvoker GemmConfig::K_Warp_Tile, UniversalGemmProblem::TransposeC, memory_operation, - GemmConfig::NumWaveGroups>>; + GemmConfig::NumWaveGroups, + false, + 1, + true>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index ed73f7e9f4..628af0e0b3 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -12,6 +12,22 @@ namespace ck_tile { +template +concept HasDataType = requires { typename T::DataType; }; + +template +struct GetDataType +{ + using type = float; +}; + +template + requires HasDataType +struct GetDataType +{ + using type = typename T::DataType; // Use T::ScaleN::DataType +}; + template ::value && !std::is_same::value; // Tiles to hold row/col scales when present - using SMType = - std::conditional_t, float>; - using SNType = - std::conditional_t, float>; + using SMType = typename GetDataType>::type; + using SNType = typename GetDataType>::type; auto sm_tile = make_static_distributed_tensor(dram_tile_distribution); auto sn_tile = make_static_distributed_tensor(dram_tile_distribution);