update scale for mxfp4

This commit is contained in:
Feng Shijie
2025-08-11 07:59:47 +00:00
parent 8ba1c708dc
commit 200a11afc8
8 changed files with 483 additions and 177 deletions

View File

@@ -40,6 +40,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr int QuantPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr int N_Pack = 2;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -47,6 +48,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static constexpr auto I4 = number<4>();
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
@@ -149,7 +151,21 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
}
}();
return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
auto scale_n = kargs.scale_n_ptr;
index_t FlatScaleK =
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
const auto scale_b_flat_view =
make_naive_tensor_view<address_space_enum::global>(scale_n.ptr,
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
return make_tuple(
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
}
template <typename TensorView>
@@ -215,7 +231,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
}
}();
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4));
}
template <typename PadView>
@@ -275,6 +291,12 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
auto scale_block_window =
make_tile_window(views.at(I4),
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
}
@@ -304,8 +326,13 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
const auto& scale_block_window = gemm_tile_windows.at(I3);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
b_flat_block_window,
scale_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||

View File

@@ -371,37 +371,6 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ADramDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<4>,
tuple<sequence<16>, sequence<4, 4, 8>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 0>>,
sequence<2>,
sequence<2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
@@ -438,42 +407,6 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = 32;
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
{

View File

@@ -7,6 +7,7 @@
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
namespace ck_tile {
@@ -37,7 +38,7 @@ struct MixedPrecFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
static constexpr index_t flatKPerWarp = 128;
};
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
template <typename Problem, typename PipelinePolicy = MixedPrecFlatmmPipelineAgBgCrPolicy>
struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
: FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
{
@@ -456,10 +457,14 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// __builtin_amdgcn_sched_barrier(0);
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
template <typename ADramBlockWindowTmp,
typename AElementFunction,
typename BFlatBlockWindowTmp,
typename DequantBFlatWindow>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const DequantBFlatWindow& scale_b_flat_window,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
@@ -565,35 +570,61 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// Acc register tile
auto c_block_tile = block_flatmm.MakeCBlockTile();
constexpr int XDLPerLoadK = 4;
constexpr int XDLPerLoadK = 4;
constexpr int NRepeatPerScaleLoad = 2;
constexpr int QuantKPerWarp = KIterPerWarp / XDLPerLoadK;
constexpr int QuantNPerWarp = NIterPerWarp / NRepeatPerScaleLoad;
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeFp4BFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
auto scale_b_flat_distribution =
PipelinePolicy::template MakeFp4ScaleBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
constexpr int ScaleB_BlockK =
flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
auto scale_b_flat_dram_window = make_tile_window(
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<ScaleB_BlockK>{}),
scale_b_flat_window.get_window_origin(),
scale_b_flat_distribution);
// pingpong buffer for B
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), QuantKPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(scale_b_flat_dram_window), QuantKPerWarp>,
QuantNPerWarp>
scale_b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), QuantKPerWarp>,
QuantNPerWarp>
scale_b_warp_tensor_ping;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), QuantKPerWarp>,
QuantNPerWarp>
scale_b_warp_tensor_pong;
// HEAD
// Prefetch A0
auto a_block_tile = load_tile(a_copy_dram_window);
@@ -603,6 +634,19 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
move_tile_window(scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter, kIter * KFlatPerBlockPerIter});
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
}
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
@@ -613,6 +657,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
@@ -643,12 +688,56 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
auto dequant_B = typename WG::BWarpTensor{};
auto deq_fn = [&](auto& quant_weight_tensor, auto sub_idx) {
auto perm_scale = [&](auto lane_scale, auto xdl_k_idx) {
#if defined(__gfx942__)
return lane_scale;
#endif
auto v2scale = __builtin_amdgcn_permlane32_swap(lane_scale, lane_scale, 0, 0);
if constexpr(xdl_k_idx < 2)
{
lane_scale = v2scale[0];
}
else
{
lane_scale = v2scale[1];
}
v2scale = __builtin_amdgcn_permlane16_swap(lane_scale, lane_scale, 0, 0);
if constexpr(xdl_k_idx % 2 == 0)
{
return v2scale[0];
}
else
{
return v2scale[1];
}
};
auto deq_fn = [&](const auto& quant_weight_tensor,
const auto& scale_tensor,
auto xdl_nIter,
auto xdl_kIter) {
auto b_idx_k = xdl_kIter % number<XDLPerLoadK>{};
auto scale_idx_n = xdl_nIter % number<NRepeatPerScaleLoad>{};
uint32_t packed_scale = scale_tensor.get_thread_buffer().template get_as<uint32_t>(I0);
packed_scale = perm_scale(packed_scale, b_idx_k);
e8m0_t* scale_ptr = reinterpret_cast<e8m0_t*>(&packed_scale);
if constexpr(xdl_nIter % 2 != 0)
{
scale_ptr++;
}
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
dequant_B.get_thread_buffer().template set_as<fp16x2_t>(
number<i>{},
fp16x2_t(quant_weight_tensor.get_thread_buffer()[sub_idx * ScalarCnt / 2 + i]));
pk_fp4_to_fp16x2(
quant_weight_tensor.get_thread_buffer()[b_idx_k * ScalarCnt / 2 + i],
*scale_ptr));
});
};
@@ -690,7 +779,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
kIter / number<XDLPerLoadK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
@@ -721,6 +813,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
@@ -765,7 +858,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
scale_b_warp_tensor_pong(nIter / number<NRepeatPerScaleLoad>{})(
kIter / number<XDLPerLoadK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
@@ -795,6 +891,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
@@ -839,7 +936,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
kIter / number<XDLPerLoadK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
@@ -889,7 +989,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
scale_b_warp_tensor_pong(nIter / number<NRepeatPerScaleLoad>{})(
kIter / number<XDLPerLoadK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
@@ -931,7 +1034,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
kIter / number<XDLPerLoadK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
@@ -964,9 +1070,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
return c_block_tile;
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename DequantBFlatWindow>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const DequantBFlatWindow& scale_b_flat_window,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
@@ -975,6 +1084,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
a_dram_block_window_tmp,
[](const ADataType & a) { return a; },
b_flat_dram_block_window_tmp,
scale_b_flat_window,
num_loop,
p_smem_ping,
p_smem_pong);

View File

@@ -0,0 +1,241 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
namespace ck_tile {
struct MixedPrecFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
}
else
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) / 4;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ADramDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<4>,
tuple<sequence<16>, sequence<4, 4, 8>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 0>>,
sequence<2>,
sequence<2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = 32;
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
constexpr index_t N_Repeat = TileShape::kN / TileShape::WarpTile::at(I1) / N_Warp;
constexpr index_t N_Pack = N_Repeat;
constexpr index_t XDLPerBlock = TileShape::kK / TileShape::WarpTile::at(I2);
constexpr index_t KBPerLoad = XDLPerBlock * N_Pack;
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
constexpr index_t K_Pack = XDLPerBlock / K_Lane;
// constexpr index_t RepeatScale = TileShape::WarpTile::at(I2) / ;
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = N_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, // ?
tuple<sequence<NWavePerBlk>, // second direction
sequence<K_Lane, 16, N_Pack * K_Pack>>, // first
// direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<1>, sequence<2, 2>>, // which direction
tuple<sequence<0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = kMPerBlock / M1;
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetSmemPackA<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size >= (K2 * M0))
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = kBlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
// using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
typename Problem::ADataType,
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename
// Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
}
};
} // namespace ck_tile