diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 07b925d0eb..af740dc0f1 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -277,7 +277,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr bool Preshuffle = true; static constexpr bool DoubleSmemBuffer = true; static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr bool TiledMMAPermuteN = N_Repeat % 4 == 0; }; template @@ -301,7 +301,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr bool Preshuffle = true; static constexpr bool DoubleSmemBuffer = true; static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr bool TiledMMAPermuteN = N_Repeat % 4 == 0; }; template diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 0f323cb0e3..dca4ef8f40 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -92,7 +92,7 @@ int main(int argc, char* argv[]) try { #if CK_TILE_USE_WMMA - return !run_gemm_example(arg_parser); + // return !run_gemm_example(arg_parser); #else return !run_gemm_example(arg_parser); #endif diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index e6875f97d5..2abc66bc4a 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -195,23 +195,32 @@ auto shuffle_b(const ck_tile::HostTensor& t) } else { - int divisor = 1; if(ck_tile::is_gfx11_supported()) { - divisor = 1; + int divisor = 1; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / ItemsPerAccess, + divisor, + ItemsPerAccess}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } else { assert(is_wave32() == false); - divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + constexpr int MaxVecSize = 16 / sizeof(T); + constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = std::min(MaxVecSize, GemmConfig::K_Warp_Tile / KLane); + + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / ItemsPerAccess, + 1, + ItemsPerAccess}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } } @@ -220,19 +229,21 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + constexpr int MaxVecSize = 16 / sizeof(T); + constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = std::min(MaxVecSize, GemmConfig::K_Warp_Tile / KLane); constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); + k_ / ItemsPerAccess, + ItemsPerAccess}); std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5}); } template diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 86110d57ec..d815b1db40 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -16,6 +16,7 @@ #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/joinable_thread.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/ranges.hpp" #include "ck_tile/host/reference/reference_batched_dropout.hpp" #include "ck_tile/host/reference/reference_batched_dropout_randval.hpp" diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 1768c802d5..6c0972e10a 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -9,5 +9,7 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index ca0088c812..5822d7b91b 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -12,5 +12,7 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 7c6adc3ec2..eff2d625b3 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" -#include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 4858245ec4..7f2303932e 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -10,5 +10,7 @@ #include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 221592ee10..6e93fc846c 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1269,7 +1269,7 @@ struct Swish struct SoftRelu { - SoftRelu(float alpha = 1.f) : alpha_(alpha){}; + SoftRelu(float alpha = 1.f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1288,7 +1288,7 @@ struct SoftRelu struct Power { Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) - : alpha_(alpha), beta_(beta), gamma_(gamma){}; + : alpha_(alpha), beta_(beta), gamma_(gamma) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1310,7 +1310,7 @@ struct Power struct ClippedRelu { - ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1329,7 +1329,7 @@ struct ClippedRelu struct LeakyRelu { - LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; + LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1346,7 +1346,7 @@ struct LeakyRelu struct Elu { - Elu(float alpha = 1.f) : alpha_(alpha){}; + Elu(float alpha = 1.f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1363,7 +1363,7 @@ struct Elu struct Logistic { - Logistic(float alpha = 1.f) : alpha_(alpha){}; + Logistic(float alpha = 1.f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 6cc0fa8540..ec5a8ef445 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -8,5 +8,7 @@ #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 585a5f5b42..e9c97f2daa 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -45,10 +45,11 @@ template + index_t kNumWaveGroups_ = 1, + bool FixedVectorSize_ = false, + index_t VectorSizeC_ = 1, + bool TiledMMAPermuteN_ = false, + index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp struct CShuffleEpilogueProblem { using AsDataType = remove_cvref_t; @@ -71,6 +72,7 @@ struct CShuffleEpilogueProblem static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; static constexpr bool FixedVectorSize = FixedVectorSize_; static constexpr index_t VectorSizeC = VectorSizeC_; + static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; static constexpr index_t kNumWaveGroups = kNumWaveGroups_; static constexpr index_t NumDTensor = DsDataType::size(); @@ -123,6 +125,7 @@ struct CShuffleEpilogue static constexpr index_t isCTransposed = Problem::isCTransposed; static constexpr bool FixedVectorSize = Problem::FixedVectorSize; static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; + static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp; static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t MPerIteration = MPerXdl * MWave; static constexpr index_t NPerIteration = NPerXdl * NWave; @@ -228,7 +231,8 @@ struct CShuffleEpilogue } }(); static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple); - static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple); + static constexpr index_t NumNXdlPerWavePerShuffle = + max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple)); static constexpr auto MNPerIterationShuffle = [] { constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle; @@ -281,14 +285,31 @@ struct CShuffleEpilogue CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode() { - constexpr auto block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + constexpr auto block_outer_dstr_encoding = [] { + if constexpr(BlockedXDLN_PerWarp == 1) + { + return tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + } + else + { + constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp; + // BlockedLayout + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}; + } + }(); constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding( block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{}); diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 1714789e63..41463e6a2d 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -14,5 +14,7 @@ #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 31de21a726..6b25c089bd 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -60,5 +60,7 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index ddb64a2189..71721f3408 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -16,5 +16,7 @@ #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 6e07dbc00e..c6f71dc399 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -30,8 +30,8 @@ #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" @@ -70,5 +70,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp index 9036d48b08..b4c87f9cd0 100644 --- a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp @@ -109,6 +109,7 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1 merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_barrier(0x7F6); }); }); }); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index f1c8f2ec9b..80e6b525bc 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -158,11 +158,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); constexpr index_t K0 = KPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; // coalesce reading for each blocks - if constexpr(get_warp_size() % (M2 * K0) == 0) + if constexpr(get_warp_size() % K0 == 0) { constexpr index_t M1 = BlockSize / get_warp_size(); + constexpr index_t M2 = get_warp_size() / K0; static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); constexpr index_t M0 = MPerBlock / (M2 * M1); @@ -180,18 +180,18 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy } else { - constexpr index_t M0 = BlockSize / get_warp_size(); - constexpr index_t M1 = MPerBlock / (M2 * M0); - static_assert(M0 * M1 * M2 == MPerBlock, - "Incorrect M0, M1, M2 configuration! " - "M0, M1, M2 must cover whole MPerBlock!"); + constexpr index_t KWave = K0 / get_warp_size(); + constexpr index_t M0 = BlockSize / get_warp_size() / KWave; + constexpr index_t M1 = MPerBlock / M0; + return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<1, 2>, + sequence<1, 2>>{}); } } } @@ -211,10 +211,15 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy #else constexpr index_t KRepeatInWave = 1; #endif - constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim - constexpr index_t KWavePerBlk = 1; - constexpr index_t KRepeat = 1; + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t MaxVecSize = 16 / sizeof(typename Problem::BDataType); + constexpr index_t KItemsPerLoad = min(KBPerLoad, MaxVecSize); + constexpr index_t KFragment = KBPerLoad / KItemsPerLoad; + static_assert(KFragment * KItemsPerLoad == KBPerLoad); + static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); + static_assert(TileShape::BlockWarps::at(number<2>{}) == 1, "Requires K_Warp == 1"); constexpr index_t NBPerLoad = 1; constexpr index_t NThdPerWave = 1; @@ -224,9 +229,10 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; return make_static_tile_distribution( tile_distribution_encoding< - sequence, // ? - tuple, // second direction - sequence>, // first direction + sequence, // ? + tuple, // second direction + sequence>, // first + // direction // wave in blk, // thd in wave // // tuple, sequence<0, 1, 2>>, // which direction @@ -284,6 +290,37 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy } } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALDS_WarpTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + + static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1"); + + constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); + constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2); + + constexpr int Repeat = TileShape::BlockWarps::at(number<1>{}); + + constexpr int KLane = get_warp_size() / MPerXdl; + constexpr int KPerThread = KPerXdl / KLane; + + constexpr int MaxVecSize = 16 / sizeof(ADataType); + constexpr int KItemsPerLoad = min(MaxVecSize, KPerThread); + constexpr int KFragment = KPerThread / KItemsPerLoad; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<2, 2>, + sequence<0, 2>>{}); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle() { diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 670f4b0575..183a42349e 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" @@ -581,13 +580,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + PipelinePolicy::template MakeALDS_WarpTileDistribution()); auto a_warp_window_pong_tmp = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + PipelinePolicy::template MakeALDS_WarpTileDistribution()); statically_indexed_array< statically_indexed_array, @@ -602,16 +601,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; move_tile_window(a_warp_windows_ping(mIter)(kIter), {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - move_tile_window(a_warp_windows_pong(mIter)(kIter), {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); @@ -661,8 +654,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -709,8 +701,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); }); @@ -786,8 +777,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); }); @@ -867,8 +857,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); }); diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 9f90050899..478f348146 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -17,5 +17,7 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 09b50f26b0..1dd13b6246 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -12,5 +12,7 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 93664ea138..2307b05190 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -7,5 +7,7 @@ #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index afbb817db1..9ce22137bf 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -10,5 +10,7 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index 7dc3e8b7e7..aa074b7f9f 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -7,5 +7,7 @@ #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 1cc3d9cbc3..46512c57fe 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -6,5 +6,7 @@ #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index a6721c9305..d628e9c945 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -11,5 +11,7 @@ #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 610541b2e4..00afcf4aed 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -11,5 +11,7 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index dc164dc1a0..1aa14c69e1 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -10,5 +10,7 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index b23e869d81..d559dc15e2 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -6,5 +6,7 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 1dc563f757..040c6b8ddc 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -6,5 +6,7 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index d0a810de4f..d9657a9764 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -8,5 +8,7 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp"