[CK TILE GEMM] set correct value to TiledMMAPermuteN_ (#2839)

- TiledMMAPermuteN_ should be set to true when config if GemmConfigPreshufflePrefill
This commit is contained in:
Cong Ma
2025-09-13 21:54:08 -06:00
committed by GitHub
parent 3a51dbba85
commit e5d73da2da
2 changed files with 22 additions and 5 deletions

View File

@@ -106,7 +106,10 @@ struct WeightPreshuffleInvoker
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
GemmConfig::NumWaveGroups,
false,
1,
true>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);

View File

@@ -12,6 +12,22 @@
namespace ck_tile {
template <typename T>
concept HasDataType = requires { typename T::DataType; };
template <typename T>
struct GetDataType
{
using type = float;
};
template <typename T>
requires HasDataType<T>
struct GetDataType<T>
{
using type = typename T::DataType; // Use T::ScaleN::DataType
};
template <typename ADataType_,
typename BDataType_,
typename DsDataType_,
@@ -423,10 +439,8 @@ struct CShuffleEpilogue
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
// Tiles to hold row/col scales when present
using SMType =
std::conditional_t<has_scales, remove_cvref_t<typename ScaleM::DataType>, float>;
using SNType =
std::conditional_t<has_scales, remove_cvref_t<typename ScaleN::DataType>, float>;
using SMType = typename GetDataType<remove_cvref_t<ScaleM>>::type;
using SNType = typename GetDataType<remove_cvref_t<ScaleN>>::type;
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);