update scale-preshuffle for MXF4

Feng Shijie
2025-08-13 10:48:53 +00:00
parent edb58d0680
commit 732ebdee8b
6 changed files with 376 additions and 401 deletions

View File

@@ -11,7 +11,7 @@ struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;

View File

@@ -97,17 +97,17 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::MixedPrecFlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::MixedPrecFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
@@ -134,7 +134,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
FlatmmConfig::TiledMMAPermuteN>>;
using Kernel =
ck_tile::MixedPrecFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -282,10 +282,11 @@ float invoke_mixed_prec_flatmm(ck_tile::DeviceMem& a_dev_buf,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
// TODO (sizeof(BDataType) / 2)
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K / PackedSize +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
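For intuition, a worked model of the corrected byte count (a hedged sketch, not library code; assumes PackedSize = 2, i.e. one stored byte of pk fp4 carries two 4-bit B elements, and 2-byte fp16 A and C):

#include <cstddef>

// Hypothetical standalone model of num_byte above for MXF4.
// B contributes N*K/2 bytes because sizeof(BDataType) == 1 holds two values.
std::size_t mxf4_num_byte(std::size_t M, std::size_t N, std::size_t K)
{
    constexpr std::size_t PackedSize = 2; // two fp4 values per byte
    return 2 * M * K            // fp16 A, 2 bytes per element
         + N * K / PackedSize   // pk fp4 B
         + 2 * M * N;           // fp16 C, 2 bytes per element
}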
@@ -366,23 +367,26 @@ auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
int n_ = scale.get_lengths()[1];
int k_ = scale.get_lengths()[0];
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
constexpr int K_Pack = FlatmmConfig::K_Tile / FlatmmConfig::K_Warp_Tile / K_Lane;
constexpr int K_Pack = 2; // fixed for mxfp4
constexpr int N_Pack = 2; // fixed for mxfp4
constexpr int GranularityK = 32; // fixed for mxfp4
static_assert(sizeof(T) * K_Pack * FlatmmConfig::N_Repeat <= 16, "inefficient pack policy");
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only supports XDL_N == 16");
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
ck_tile::HostTensor<T> shfl_scale({
k_ / K_Pack / K_Lane,
K_Pack,
K_Lane,
n_ / FlatmmConfig::N_Tile,
FlatmmConfig::N_Repeat,
FlatmmConfig::N_Warp,
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
N_Pack,
FlatmmConfig::N_Warp_Tile,
});
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
// return ck_tile::reference_permute(shfl_scale, {0, 3, 5, 2, 4, 1});
return ck_tile::reference_permute(shfl_scale, {3, 5, 0, 2, 6, 1, 4});
return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4});
}
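A short standalone model of what the new 6-D view plus permute {3, 0, 2, 5, 1, 4} does (hedged: plain C++ without ck_tile types, and it assumes reference_permute places input dim perm[i] at output position i; the function name is hypothetical):

#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical standalone model of preShuffleScale above for mxfp4.
// scale is a row-major k_len x n_len buffer, as in the HostTensor source.
std::vector<float> preshuffle_scale_model(const std::vector<float>& scale,
                                          int k_len, int n_len)
{
    constexpr int K_Pack = 2, N_Pack = 2, K_Lane = 4, N_Warp_Tile = 16;
    const std::array<int, 6> dims = {k_len / K_Pack / K_Lane, K_Pack, K_Lane,
                                     n_len / N_Warp_Tile / N_Pack, N_Pack,
                                     N_Warp_Tile};
    constexpr std::array<int, 6> perm = {3, 0, 2, 5, 1, 4};
    // Row-major strides of the 6-D view over the original buffer
    // (mirrors the std::copy into shfl_scale above).
    std::array<std::size_t, 6> stride{};
    std::size_t s = 1;
    for(int d = 5; d >= 0; --d) { stride[d] = s; s *= dims[d]; }
    assert(s == scale.size());
    std::vector<float> out(scale.size());
    std::array<int, 6> idx{}; // multi-index over the *output* dims
    for(std::size_t o = 0; o < out.size(); ++o)
    {
        std::size_t src = 0;
        for(int d = 0; d < 6; ++d) src += std::size_t(idx[d]) * stride[perm[d]];
        out[o] = scale[src];
        for(int d = 5; d >= 0; --d) // odometer increment over output dims
        {
            if(++idx[d] < dims[perm[d]]) break;
            idx[d] = 0;
        }
    }
    return out;
}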
#include "run_mixed_prec_flatmm.inc"

View File

@@ -59,8 +59,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
// ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
}
else if(init_method == 1)
{
@@ -166,7 +165,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
const float rtol = 1e-3;
const float rtol = 5e-3;
const float atol = 1e-3;
pass = ck_tile::check_err(

View File

@@ -13,11 +13,8 @@
namespace ck_tile {
template <typename TilePartitioner_,
typename FlatmmPipeline_,
typename EpiloguePipeline_,
int SupportArch = 0> // 0 means no arch restriction
struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
{
using Underlying = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;

View File

@@ -20,26 +20,32 @@ template <typename ADataType_,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
struct MixedPrecFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
ADataType_,
CDataType_,
BlockGemmShape_,
Traits_,
Scheduler_,
HasHotLoop_,
TailNum_,
ComputeDataType_>
struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
ADataType_,
CDataType_,
BlockGemmShape_,
Traits_,
Scheduler_,
HasHotLoop_,
TailNum_,
ComputeDataType_>
{
using BlockGemmShape = BlockGemmShape_;
using QuantType = BDataType_;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr index_t flatKPerWarp = 128;
static constexpr int MXF4ScaleGranularityK = 32;
static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
static constexpr int ContinuousScaleNPerThread = 2; // it's fixed for fp4
static constexpr int ContinuousScaleKPerThread = 2; // it's fixed for fp4
static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread;
};
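For reference, MXF4ScaleGranularityK = 32 is the MX block size: one shared scale (E8M0 in the OCP MX formats) per 32 consecutive K elements of B. A hedged sketch of the resulting scale count (hypothetical helper, not library code):

#include <cstddef>

// Number of B scales for an N x K MXF4 weight, one scale per 32 K values.
constexpr std::size_t num_b_scales(std::size_t N, std::size_t K)
{
    constexpr std::size_t GranularityK = 32;
    return N * (K / GranularityK); // assumes K % 32 == 0
}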
template <typename Problem, typename PipelinePolicy = MixedPrecFlatmmPipelineAgBgCrPolicy>
struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
template <typename Problem, typename PipelinePolicy = F16xMXF4FlatmmPipelineAgBgCrPolicy>
struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
: FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
{
using Underlying = FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>;
@@ -117,6 +123,31 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
? DsReadPreload
: MIterPerWarp * KIterPerWarp;
static constexpr int ContinuousKPerThread = Problem::ContinuousKPerThread;
static constexpr int ContinuousScaleNPerThread = Problem::ContinuousScaleNPerThread;
static constexpr int ContinuousScaleKPerThread = Problem::ContinuousScaleKPerThread;
static constexpr int MXFP4PackedSize = 2;
static constexpr int ScaleKFlatPerWarp =
ContinuousScaleNPerThread * ContinuousScaleKPerThread * get_warp_size();
static constexpr int XDLK_PerThread =
WarpTile::at(I2) / (get_warp_size() / WarpTile::at(I1)); // 8
static constexpr int XDL_PerWeightK = 4; // one 32-value fp4 load feeds 4 XDLs
static constexpr int XDL_PerScaleK = XDL_PerWeightK * ContinuousScaleKPerThread; // 8
static constexpr int XDL_PerScaleN = ContinuousScaleNPerThread; // 2
static_assert(XDL_PerScaleK % XDL_PerWeightK == 0);
static_assert(KIterPerWarp % XDL_PerScaleK == 0);
static_assert(NIterPerWarp % XDL_PerScaleN == 0);
static constexpr int MXFP4KPerWarp = KIterPerWarp / XDL_PerWeightK;
static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK;
static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN;
static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
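To make the new iteration bookkeeping concrete, a hedged worked example: it assumes K_Tile = 256 and WG::kK = 32 (consistent with the "// 8" note on XDLK_PerThread), so KIterPerWarp = 8; the numbers are illustrative, not asserted by the diff:

// Illustrative arithmetic for the constants above (hypothetical config).
constexpr int KIterPerWarp      = 256 / 32;            // 8 XDLs along K
constexpr int XDL_PerWeightK    = 4;                   // 32 fp4 / 8 per XDL
constexpr int XDL_PerScaleK     = XDL_PerWeightK * 2;  // 8: one scale load spans 2 weight loads
constexpr int MXFP4KPerWarp     = KIterPerWarp / XDL_PerWeightK;    // 2 weight loads
constexpr int ScaleKPerWarp     = KIterPerWarp / XDL_PerScaleK;     // 1 scale load
constexpr int MXFP4K_PerScaleK  = MXFP4KPerWarp / ScaleKPerWarp;    // 2
constexpr int ScaleKFlatPerWarp = 2 * 2 * 64;  // N_Pack x K_Pack x wave = 256
static_assert(MXFP4K_PerScaleK == 2 && ScaleKFlatPerWarp == 256);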
@@ -142,27 +173,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV1",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
concat('x', WG::kM, WG::kN, WG::kK),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
// For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false.
static constexpr bool DoubleSmemBuffer = false;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return PipelinePolicy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
{
@@ -502,6 +515,15 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
auto a_lds_block_pong =
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
auto A_XDL_TileDist = make_static_tile_distribution(typename WG::AWarpDstrEncoding{});
auto A_Lds_TileDist =
PipelinePolicy::template MakeFp16xF4_DS_WRITE_ATileDistribution<Problem>();
auto A_Lds_Stride = WG::kK;
// auto A_XDL_TileDist = PipelinePolicy::template
// MakeF16xF4_ALDS_TileDistribution<Problem>(); auto A_Lds_TileDist =
// PipelinePolicy::template MakeADramTileDistribution<Problem>(); auto A_Lds_Stride = 8;
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -513,27 +535,26 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
make_tile_window(a_lds_block_ping,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
A_Lds_TileDist);
auto a_copy_lds_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
A_Lds_TileDist);
auto A_Warp_Dist = PipelinePolicy::template MakeF16xF4_ADramDistribution<Problem>();
// ping-pong window for A LDS
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
A_Warp_Dist);
A_XDL_TileDist);
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
A_Warp_Dist);
A_XDL_TileDist);
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
@@ -545,23 +566,26 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
MIterPerWarp>
a_warp_windows_pong;
constexpr int KStridePerIter = 8;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KStridePerIter});
});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
// auto weight_k_idx = kIter / number<XDL_PerWeightK>{};
// auto weight_k_rank = kIter % number<XDL_PerWeightK>{};
// move_tile_window(
// a_warp_windows_ping(mIter)(kIter),
// {mIter * MPerBlockPerIter,
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
// move_tile_window(
// a_warp_windows_pong(mIter)(kIter),
// {mIter * MPerBlockPerIter,
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KStridePerIter});
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
@@ -570,12 +594,6 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// Acc register tile
auto c_block_tile = block_flatmm.MakeCBlockTile();
constexpr int XDLPerLoadK = 4;
constexpr int NRepeatPerScaleLoad = 2;
constexpr int QuantKPerWarp = KIterPerWarp / XDLPerLoadK;
constexpr int QuantNPerWarp = NIterPerWarp / NRepeatPerScaleLoad;
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeFp4BFlatDramTileDistribution<Problem>();
@@ -588,41 +606,37 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
constexpr int ScaleB_BlockK = 16 * 2 * 4;
// flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
auto scale_b_flat_dram_window = make_tile_window(
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<ScaleB_BlockK>{}),
make_tuple(number<flatNPerWarp>{}, number<ScaleKFlatPerWarp>{}),
scale_b_flat_window.get_window_origin(),
scale_b_flat_distribution);
// pingpong buffer for B
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), QuantKPerWarp>,
statically_indexed_array<decltype(b_flat_dram_window), MXFP4KPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(scale_b_flat_dram_window), QuantKPerWarp>,
QuantNPerWarp>
scale_b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), MXFP4KPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), MXFP4KPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), QuantKPerWarp>,
QuantNPerWarp>
statically_indexed_array<decltype(scale_b_flat_dram_window), ScaleKPerWarp>,
ScaleNPerWarp>
scale_b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), ScaleKPerWarp>,
ScaleNPerWarp>
scale_b_warp_tensor_ping;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), QuantKPerWarp>,
QuantNPerWarp>
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), ScaleKPerWarp>,
ScaleNPerWarp>
scale_b_warp_tensor_pong;
// HEAD
@@ -633,35 +647,42 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
{scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp});
scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) =
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
}
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
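Hedged side note on the new B-window offsets: with N_Pack = 2, two adjacent nIter values share one flat-N slot and are distinguished by a rank of 0 or 1, which is what packed_n_idx / packed_n_rank compute. An illustrative index helper (hypothetical, not library code):

// nIter -> flat-N offset under the N_Pack = 2 interleave used above.
constexpr int flat_n_offset(int nIter, int n_flat_per_iter)
{
    const int packed_n_idx  = nIter / 2; // ContinuousScaleNPerThread == 2
    const int packed_n_rank = nIter % 2;
    return packed_n_idx * 2 * n_flat_per_iter + packed_n_rank;
}
// e.g. nIter 0,1,2,3 -> 0, 1, 2*NFlat, 2*NFlat + 1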
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
// A_Lds_TileDist may differ from ADramTileDistribution
auto a_block_tile_transformed = make_static_distributed_tensor<ComputeType>(A_Lds_TileDist);
a_block_tile_transformed.get_thread_buffer() =
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
__builtin_amdgcn_sched_barrier(0);
// Prefetch A1
@@ -689,64 +710,44 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
auto dequant_B = typename WG::BWarpTensor{};
auto perm_scale = [&](auto lane_scale, auto xdl_k_idx) {
#if defined(__gfx942__)
lane_scale = __builtin_amdgcn_ds_bpermute(((get_lane_id() % 16) + 16 * xdl_k_idx) * 4,
lane_scale);
return lane_scale;
#endif
auto v2scale = __builtin_amdgcn_permlane32_swap(lane_scale, lane_scale, 0, 0);
if constexpr(xdl_k_idx < 2)
{
lane_scale = v2scale[0];
}
else
{
lane_scale = v2scale[1];
}
auto dequant_mxfp4 = [&](const auto& quant_weight_tensor,
const auto& scale_tensor,
auto xdl_nIter,
auto xdl_kIter) {
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
v2scale = __builtin_amdgcn_permlane16_swap(lane_scale, lane_scale, 0, 0);
if constexpr(xdl_k_idx % 2 == 0)
{
lane_scale = v2scale[0];
}
else
{
lane_scale = v2scale[1];
}
return lane_scale;
};
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
auto deq_fn = [&](const auto& quant_weight_tensor,
const auto& scale_tensor,
auto xdl_nIter,
auto xdl_kIter) {
auto b_idx_k = xdl_kIter % number<XDLPerLoadK>{};
auto scale_idx_n = xdl_nIter % number<NRepeatPerScaleLoad>{};
auto scale = scale_tensor.get_thread_buffer()[scale_idx_n];
auto use_scale = scale;
use_scale.data = perm_scale(scale.data, b_idx_k);
if constexpr(xdl_nIter == 0)
if(blockIdx.x == 0 && threadIdx.x < 64 && get_lane_id() % 16 == 0)
{
printf("laneid = %2u xdl-k=%2d use-scale = "
"%.2f\n",
threadIdx.x,
int(xdl_kIter),
float(use_scale));
}
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
dequant_B.get_thread_buffer().template set_as<fp16x2_t>(
number<i>{},
pk_fp4_to_fp16x2(
quant_weight_tensor.get_thread_buffer()[b_idx_k * ScalarCnt / 2 + i],
static_cast<float>(use_scale)));
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
if constexpr(std::is_same_v<ComputeType, half_t>)
{
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
}
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
{
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
}
else
{
static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
}
};
using ComputeV2Type =
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
static_for<0, PackedCnt, 1>{}([&](auto i) {
dequant_B.get_thread_buffer().template set_as<ComputeV2Type>(
i,
pk_mxfp4_to_compute_v2(
quant_weight_tensor.get_thread_buffer()[quant_idx_k * PackedCnt + i],
static_cast<float>(scale)));
});
};
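A scalar model of the per-pair conversion (hedged: this is not the ck_tile intrinsic, and the nibble order is an assumption; it only illustrates the e2m1-times-scale math that pk_fp4_to_fp16x2 / pk_fp4_to_bf16x2 perform in packed form):

#include <cstdint>
#include <utility>

// e2m1 (fp4) magnitude table: 1 sign, 2 exponent, 1 mantissa bit.
static const float kE2M1[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};

float e2m1_to_float(std::uint8_t nib)
{
    float mag = kE2M1[nib & 0x7];
    return (nib & 0x8) ? -mag : mag;
}

// Dequantize one packed byte (two fp4 values) with the shared MX scale;
// assumes the low nibble holds the first element.
std::pair<float, float> dequant_pk_fp4(std::uint8_t packed, float scale)
{
    return {e2m1_to_float(packed & 0xF) * scale,
            e2m1_to_float(static_cast<std::uint8_t>(packed >> 4)) * scale};
}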
@@ -755,34 +756,43 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
while(iCounter > 0)
{
// prefetch B(2i+1)
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
{scale_n_iter * NFlatPerBlockPerIter,
scale_k_iter * ScaleKFlatPerWarp});
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) =
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
}
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
a_block_tile_transformed.get_thread_buffer() =
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
// Prefetch A(2i+2)
a_block_tile = load_tile(a_copy_dram_window);
@@ -801,11 +811,11 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
kIter / number<XDLPerLoadK>{}),
nIter,
kIter);
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
@@ -835,8 +845,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
@@ -849,34 +859,42 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// Next K
// prefetch B(2i+2)
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
{scale_n_iter * NFlatPerBlockPerIter,
scale_k_iter * ScaleKFlatPerWarp});
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) =
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
}
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(2i+2)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
a_block_tile_transformed.get_thread_buffer() =
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
// Prefetch A(2i+3)
a_block_tile = load_tile(a_copy_dram_window);
@@ -894,11 +912,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
scale_b_warp_tensor_pong(nIter / number<NRepeatPerScaleLoad>{})(
kIter / number<XDLPerLoadK>{}),
nIter,
kIter);
dequant_mxfp4(b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
@@ -927,8 +946,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
@@ -945,34 +964,43 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
{scale_n_iter * NFlatPerBlockPerIter,
scale_k_iter * ScaleKFlatPerWarp});
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) =
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
}
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(loopK)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
a_block_tile_transformed.get_thread_buffer() =
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
// GEMM loopK-1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
@@ -986,11 +1014,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
kIter / number<XDLPerLoadK>{}),
nIter,
kIter);
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
@@ -1039,11 +1068,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
scale_b_warp_tensor_pong(nIter / number<NRepeatPerScaleLoad>{})(
kIter / number<XDLPerLoadK>{}),
nIter,
kIter);
dequant_mxfp4(b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
@@ -1084,11 +1114,11 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
kIter / number<XDLPerLoadK>{}),
nIter,
kIter);
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);

View File

@@ -7,76 +7,110 @@
namespace ck_tile {
struct MixedPrecFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
}
else
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
}
}
static constexpr index_t KBPerLoad = 32;
static constexpr index_t N_Pack = 2; // it's fixed for fp4
static constexpr index_t K_Pack = 2; // it's fixed for fp4
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) / 4;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ADramDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeFp16xF4_ADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<4>,
tuple<sequence<16>, sequence<4, 4, 8>>,
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeFp16xF4_DS_WRITE_ATileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
// unmerge K0 into K128_Cnt x K4_1 x K4_2 (each 4-wide factor = XDL_PerKBLoad)
// then exchange the order of K4_1 and K4_2
constexpr index_t XDL_PerKBLoad = 4;
constexpr index_t K128_Cnt = K0 / XDL_PerKBLoad / XDL_PerKBLoad;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K128_Cnt, XDL_PerKBLoad, XDL_PerKBLoad, K1>>,
tuple<sequence<1>, sequence<1, 2, 2, 2>>,
tuple<sequence<1>, sequence<2, 0, 2, 1>>,
sequence<1, 2>,
sequence<0, 3>>{});
}
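A loose scalar model of the K reindexing this distribution encodes (hedged: tile_distribution_encoding semantics are more involved, but the "unmerge then swap" step in the comment above amounts to the following, with a hypothetical helper name):

// Split a K0 index into (k128, k4_hi, k4_lo) and swap the two 4-wide factors.
constexpr int XDL_PerKBLoad_model = 4;
constexpr int swap_k0(int k0)
{
    const int k128 = k0 / (XDL_PerKBLoad_model * XDL_PerKBLoad_model);
    const int hi   = (k0 / XDL_PerKBLoad_model) % XDL_PerKBLoad_model;
    const int lo   = k0 % XDL_PerKBLoad_model;
    return (k128 * XDL_PerKBLoad_model + lo) * XDL_PerKBLoad_model + hi;
}
static_assert(swap_k0(1) == 4 && swap_k0(4) == 1, "adjacent 4-wide factors swapped");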
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALDS_TileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
constexpr int M0 = TileShape::WarpTile::at(I0);
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8
constexpr int XDL_PerThreadK = KBPerLoad / K2; // 4
constexpr int K0 = K_Lane; // 4
return make_static_tile_distribution(
tile_distribution_encoding<sequence<Repeat>,
tuple<sequence<M0>, sequence<K0, XDL_PerThreadK, K2>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 0>>,
sequence<2>,
@@ -86,37 +120,34 @@ struct MixedPrecFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
using TileShape = typename Problem::BlockGemmShape;
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = 32;
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
sequence<WaveRepeat>, // ?
tuple<sequence<NWavePerBlk, N_Pack>, // second
// direction
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
tuple<sequence<0, 0, 0>, sequence<1>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
sequence<2>,
sequence<2>>{});
}
template <typename Problem>
@@ -130,111 +161,25 @@ struct MixedPrecFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
constexpr index_t N_Repeat = TileShape::kN / TileShape::WarpTile::at(I1) / N_Warp;
constexpr index_t N_Pack = N_Repeat;
constexpr index_t XDLPerBlock = TileShape::kK / TileShape::WarpTile::at(I2);
constexpr index_t KBPerLoad = XDLPerBlock * N_Pack;
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
constexpr index_t K_Pack = XDLPerBlock / K_Lane;
// constexpr index_t RepeatScale = TileShape::WarpTile::at(I2) / ;
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = N_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, // ?
tuple<sequence<NWavePerBlk>, // second direction
sequence<K_Lane, 16, N_Pack * K_Pack>>, // first
// direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<1>, sequence<2, 2>>, // which direction
tuple<sequence<0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = kMPerBlock / M1;
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetSmemPackA<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size >= (K2 * M0))
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = kBlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
// using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
typename Problem::ADataType,
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename
// Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
tile_distribution_encoding<
sequence<>, // ?
tuple<sequence<NWavePerBlk>, // second direction
sequence<K_Lane, N_Lane, N_Pack * K_Pack>>, // first
// direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<1>, sequence<2, 2>>, // which direction
tuple<sequence<0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<2>>{});
}
};