mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK TILE GEMM] set correct value to TiledMMAPermuteN_ (#2839)
- TiledMMAPermuteN_ should be set to true when config if GemmConfigPreshufflePrefill
This commit is contained in:
@@ -106,7 +106,10 @@ struct WeightPreshuffleInvoker
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
true>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
|
||||
@@ -12,6 +12,22 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
concept HasDataType = requires { typename T::DataType; };
|
||||
|
||||
template <typename T>
|
||||
struct GetDataType
|
||||
{
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
requires HasDataType<T>
|
||||
struct GetDataType<T>
|
||||
{
|
||||
using type = typename T::DataType; // Use T::ScaleN::DataType
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename DsDataType_,
|
||||
@@ -423,10 +439,8 @@ struct CShuffleEpilogue
|
||||
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
|
||||
|
||||
// Tiles to hold row/col scales when present
|
||||
using SMType =
|
||||
std::conditional_t<has_scales, remove_cvref_t<typename ScaleM::DataType>, float>;
|
||||
using SNType =
|
||||
std::conditional_t<has_scales, remove_cvref_t<typename ScaleN::DataType>, float>;
|
||||
using SMType = typename GetDataType<remove_cvref_t<ScaleM>>::type;
|
||||
using SNType = typename GetDataType<remove_cvref_t<ScaleN>>::type;
|
||||
|
||||
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
|
||||
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
|
||||
|
||||
Reference in New Issue
Block a user