preshuffle reformat

This commit is contained in:
lalala-sh
2025-09-29 02:17:20 +00:00
committed by valarLip
parent e44e37654c
commit 67ded700d1
30 changed files with 179 additions and 78 deletions

View File

@@ -277,7 +277,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
static constexpr bool TiledMMAPermuteN = N_Repeat % 4 == 0;
};
template <typename PrecType>
@@ -301,7 +301,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
static constexpr bool TiledMMAPermuteN = N_Repeat % 4 == 0;
};
template <typename PrecType>

View File

@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigPreshufflePrefill_Wmma>(arg_parser);
// return !run_gemm_example<GemmConfigPreshufflePrefill_Wmma>(arg_parser);
#else
return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
#endif

View File

@@ -195,23 +195,32 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
int divisor = 1;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / ItemsPerAccess,
divisor,
ItemsPerAccess});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
constexpr int MaxVecSize = 16 / sizeof(T);
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
constexpr int ItemsPerAccess = std::min(MaxVecSize, GemmConfig::K_Warp_Tile / KLane);
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / ItemsPerAccess,
1,
ItemsPerAccess});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
@@ -220,19 +229,21 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int MaxVecSize = 16 / sizeof(T);
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
constexpr int ItemsPerAccess = std::min(MaxVecSize, GemmConfig::K_Warp_Tile / KLane);
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
k_ / ItemsPerAccess,
ItemsPerAccess});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5});
}
template <typename CDataType>

View File

@@ -16,6 +16,7 @@
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
#include "ck_tile/host/reference/reference_batched_dropout_randval.hpp"

View File

@@ -9,5 +9,7 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -12,5 +12,7 @@
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"

View File

@@ -10,5 +10,7 @@
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -1269,7 +1269,7 @@ struct Swish
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
SoftRelu(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1288,7 +1288,7 @@ struct SoftRelu
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
: alpha_(alpha), beta_(beta), gamma_(gamma) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1310,7 +1310,7 @@ struct Power
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1329,7 +1329,7 @@ struct ClippedRelu
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1346,7 +1346,7 @@ struct LeakyRelu
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
Elu(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1363,7 +1363,7 @@ struct Elu
struct Logistic
{
Logistic(float alpha = 1.f) : alpha_(alpha){};
Logistic(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const

View File

@@ -8,5 +8,7 @@
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -45,10 +45,11 @@ template <typename AsDataType_,
index_t KPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false>
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
struct CShuffleEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
@@ -71,6 +72,7 @@ struct CShuffleEpilogueProblem
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -123,6 +125,7 @@ struct CShuffleEpilogue
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
@@ -228,7 +231,8 @@ struct CShuffleEpilogue
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle =
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
static constexpr auto MNPerIterationShuffle = [] {
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
@@ -281,14 +285,31 @@ struct CShuffleEpilogue
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(BlockedXDLN_PerWarp == 1)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});

View File

@@ -14,5 +14,7 @@
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -60,5 +60,7 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -16,5 +16,7 @@
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -30,8 +30,8 @@
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
@@ -70,5 +70,7 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -109,6 +109,7 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
});
});

View File

@@ -158,11 +158,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (M2 * K0) == 0)
if constexpr(get_warp_size() % K0 == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
constexpr index_t M2 = get_warp_size() / K0;
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
@@ -180,18 +180,18 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
constexpr index_t KWave = K0 / get_warp_size();
constexpr index_t M0 = BlockSize / get_warp_size() / KWave;
constexpr index_t M1 = MPerBlock / M0;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M0, M1>, sequence<KWave, get_warp_size(), K1>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<0, 0>, sequence<1>>,
sequence<1, 2>,
sequence<1, 2>>{});
}
}
}
@@ -211,10 +211,15 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t MaxVecSize = 16 / sizeof(typename Problem::BDataType);
constexpr index_t KItemsPerLoad = min(KBPerLoad, MaxVecSize);
constexpr index_t KFragment = KBPerLoad / KItemsPerLoad;
static_assert(KFragment * KItemsPerLoad == KBPerLoad);
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
static_assert(TileShape::BlockWarps::at(number<2>{}) == 1, "Requires K_Warp == 1");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
@@ -224,9 +229,10 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KFragment, KWavePerBlk, KThdPerWave, KItemsPerLoad>>, // first
// direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
@@ -284,6 +290,37 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALDS_WarpTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
constexpr int KLane = get_warp_size() / MPerXdl;
constexpr int KPerThread = KPerXdl / KLane;
constexpr int MaxVecSize = 16 / sizeof(ADataType);
constexpr int KItemsPerLoad = min(MaxVecSize, KPerThread);
constexpr int KFragment = KPerThread / KItemsPerLoad;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<Repeat>,
tuple<sequence<MPerXdl>, sequence<KFragment, KLane, KItemsPerLoad>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle()
{

View File

@@ -4,7 +4,6 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
@@ -581,13 +580,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
PipelinePolicy::template MakeALDS_WarpTileDistribution<Problem>());
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
PipelinePolicy::template MakeALDS_WarpTileDistribution<Problem>());
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
@@ -602,16 +601,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
@@ -661,8 +654,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// move B window to next flat K
@@ -709,8 +701,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
@@ -786,8 +777,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
@@ -867,8 +857,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});

View File

@@ -17,5 +17,7 @@
#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -12,5 +12,7 @@
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -7,5 +7,7 @@
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -10,5 +10,7 @@
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -7,5 +7,7 @@
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -6,5 +6,7 @@
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -11,5 +11,7 @@
#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -11,5 +11,7 @@
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -10,5 +10,7 @@
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -6,5 +6,7 @@
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -6,5 +6,7 @@
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -8,5 +8,7 @@
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"