Dev/a8w4 and a8w8splitk (#3447)

* Ck moe bs splitk pr (#3440)

* splitk kick-off. Compilation fail

* splitk hack pass

* fix scale offset calc.

* clang-format for a8w8_moe_blk_gemm1 splitk change

* fix testcase error

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com>

* Zan/moe a8w4 (#3441)

* update

* update

* update ck moe a8w4

* update

* update

* update

* compile pass

* update

* update

* python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready

* support new a8w4 kernel

* update

* update ck_tile

* re format

* update

* update

* fix conflict

* fix build

* update ck_tile moe

* fix clang format

* fix the problem

* fix accruacy issue

* fix

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com>
Co-authored-by: Zzz9990 <zanzhang@amd.com>
Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
yadaish
2025-12-19 09:26:52 +08:00
committed by GitHub
parent ba897f8435
commit c0ee71d735
13 changed files with 2911 additions and 139 deletions

View File

@@ -217,6 +217,7 @@ struct MoeFlatmmKernel
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static constexpr auto I4 = number<4>();
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
@@ -241,12 +242,24 @@ struct MoeFlatmmKernel
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
// MXF4_Pipeline only has the of scale B and granularityK is 32
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
static constexpr int MXFP4N_Pack = 2;
static constexpr int MXFP4K_Pack = 2;
static constexpr bool AQUANT_Pipeline = std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<ADataType, pk_fp4_t>;
static constexpr bool BMXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1;
static constexpr int K_Pack = MXFP4_Pipeline ? MXFP4K_Pack : 1;
static constexpr bool MXF8F6F4MFMA =
#ifdef __gfx950__
AQUANT_Pipeline && BMXFP4_Pipeline;
#else
false;
#endif
static constexpr int MXFP4M_Pack = 2;
static constexpr int MXFP4N_Pack = 2;
static constexpr int MXFP4K_Pack = 2;
static constexpr int M_Pack = AQUANT_Pipeline ? MXFP4M_Pack : 1;
static constexpr int N_Pack = BMXFP4_Pipeline ? MXFP4N_Pack : 1;
static constexpr int K_Pack = BMXFP4_Pipeline ? MXFP4K_Pack : 1;
static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;
@@ -659,23 +672,95 @@ struct MoeFlatmmKernel
}
}();
auto scale_n = kargs.scale_n;
constexpr int GranularityK = decltype(scale_n)::GranularityK;
const auto& scale_a_tensor_view = [&]() {
auto scale_m_desc = kargs.scale_m;
if constexpr(AQUANT_Pipeline)
{
constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK == 0
? 1
: decltype(scale_m_desc)::GranularityK;
index_t scale_k = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK;
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
const auto scale_a_desc = transform_tensor_descriptor(
scale_a_naive_desc,
make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
}
else
{
constexpr int AGranularityK = 32;
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
return make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_m_desc.ptr),
make_tuple(scale_m_packs * MThreadPerXdl, scale_k_packs * KThreadPerXdl),
make_tuple(scale_k_packs * KThreadPerXdl, 1),
number<8>{},
number<1>{});
}
}();
using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
const auto scale_b_flat_view = [&]() {
auto scale_n = kargs.scale_n;
constexpr int BGranularityK =
decltype(scale_n)::GranularityK == 0 ? 1 : decltype(scale_n)::GranularityK;
if constexpr(AQUANT_Pipeline)
{
index_t scale_k =
BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(I1);
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I1);
index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl);
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl);
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_n_packs, scale_k_packs, KThreadPerXdl, NThreadPerXdl));
const auto scale_b_desc = transform_tensor_descriptor(
scale_b_navie_desc,
make_tuple(make_merge_transform(make_tuple(scale_n_packs, NThreadPerXdl)),
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_n.ptr) +
expert_id * kargs.N * scale_k / 4,
scale_b_desc);
}
else
{
index_t scale_k =
BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view);
using ScaleType = e8m0_t;
return make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view,
b_flat_tensor_view,
c_tensor_view,
scale_a_tensor_view,
scale_b_flat_view);
}
template <typename TensorView>
@@ -718,7 +803,7 @@ struct MoeFlatmmKernel
}
}();
return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3));
return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3), views.at(I4));
}
template <typename PadView>
@@ -747,7 +832,7 @@ struct MoeFlatmmKernel
}
}();
constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
constexpr bool isNonInterleaveGateUp = !IsGateUp || BMXFP4_Pipeline;
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
@@ -766,17 +851,40 @@ struct MoeFlatmmKernel
output_N_offset});
constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
auto a_scale_block_window = make_tile_window(
views.at(I3),
make_tuple(number<TilePartitioner::MPerBlock / M_Pack>{},
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
{coord_m / M_Pack, 0});
constexpr int XDLPerLoadScaleB =
MXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
BMXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
auto scale_block_window =
make_tile_window(views.at(I3),
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
XDLPerLoadScaleB / GranularityK>{}),
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
auto b_scale_block_window = [&]() {
if constexpr(MXF8F6F4MFMA)
{
return make_tile_window(
views.at(I4),
make_tuple(number<TilePartitioner::NPerBlock / N_Pack>{},
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
{coord_n / N_Pack, 0});
}
else
{
return make_tile_window(
views.at(I4),
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
XDLPerLoadScaleB / GranularityK>{}),
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
}
}();
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
return make_tuple(a_block_window,
b_flat_block_window,
c_block_window,
a_scale_block_window,
b_scale_block_window);
}
template <class MoeFlatmmKernelArgs>
@@ -831,7 +939,6 @@ struct MoeFlatmmKernel
if(coord_m >= max_token_id)
return;
static_for<0, DramMRepeat, 1>{}([&](auto m0) {
const auto row_idx =
coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];
@@ -864,9 +971,10 @@ struct MoeFlatmmKernel
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& scale_block_window = gemm_tile_windows.at(I3);
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& a_scale_block_window = gemm_tile_windows.at(I3);
const auto& b_scale_block_window = gemm_tile_windows.at(I4);
auto a_gather_block_tile =
ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(),
@@ -876,17 +984,32 @@ struct MoeFlatmmKernel
a_offsets); // K DRAM tile window for
auto c_block_tile = [&] {
if constexpr(MXFP4_Pipeline)
if constexpr(BMXFP4_Pipeline)
{
// MXFP4_Pipeline uses gate-up interleave 16 layout for weight
// BMXFP4_Pipeline uses gate-up interleave 16 layout for weight
// so don't need extra processing
return FlatmmPipeline{}(a_gather_block_tile,
b_block_window,
scale_block_window, // weight scale with granularityK = 32
num_loop,
kargs.k_padded_zeros,
smem_ptr_ping,
smem_ptr_pong);
if constexpr(AQUANT_Pipeline)
{
return FlatmmPipeline{}(
a_gather_block_tile,
b_block_window,
a_scale_block_window, // weight scale with granularityK = 32
b_scale_block_window, // weight scale with granularityK = 32
num_loop,
smem_ptr_ping,
smem_ptr_pong);
}
else
{
return FlatmmPipeline{}(
a_gather_block_tile,
b_block_window,
b_scale_block_window, // weight scale with granularityK = 32
num_loop,
kargs.k_padded_zeros,
smem_ptr_ping,
smem_ptr_pong);
}
}
else
{
@@ -964,7 +1087,7 @@ struct MoeFlatmmKernel
constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2;
statically_indexed_array<index_t, ScaleMRepeat> scale_m_offsets;
if constexpr(!MXFP4_Pipeline)
if constexpr(!BMXFP4_Pipeline)
static_for<0, MRepeat, 1>{}([&](auto mIter) {
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
@@ -1059,7 +1182,7 @@ struct MoeFlatmmKernel
number<1>{});
auto exp_bias_window = make_tile_window(
permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}),
permute_tensor_view(exp_bias_view, number<(BMXFP4_Pipeline && !IsInputGemm)>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number < IsGateUp ? TilePartitioner::NPerBlock / 2
: TilePartitioner::NPerBlock > {}),
@@ -1101,7 +1224,7 @@ struct MoeFlatmmKernel
ExpBiasBuffer exp_bias_buffer, exp_bias_up_buffer;
ExpWeightBuffer exp_weight_buffer;
if constexpr(!MXFP4_Pipeline)
if constexpr(!BMXFP4_Pipeline)
{
scale_m_window.load(scale_m_buffer);
scale_n_buffer = load_tile(scale_n_window);
@@ -1233,7 +1356,7 @@ struct MoeFlatmmKernel
auto epi_exp_bias_up = epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n);
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
if constexpr(!MXFP4_Pipeline)
if constexpr(!BMXFP4_Pipeline)
{
gate_tensor.get_thread_buffer()[idx] *=
epi_scale_m[idx] * epi_scale_n[idx];
@@ -1260,7 +1383,7 @@ struct MoeFlatmmKernel
auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
if constexpr(!MXFP4_Pipeline)
if constexpr(!BMXFP4_Pipeline)
lds_tile[lds_stage].get_thread_buffer()[idx] *=
epi_scale_m[idx] * epi_scale_n[idx];
if constexpr(EnableBias)

View File

@@ -156,7 +156,7 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8
constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 128 / 4 = 32
constexpr int XDL_PerThreadK = KBPerLoad / K2; // 4
constexpr int K0 = K_Lane; // 4
@@ -236,4 +236,513 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
};
struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr index_t kDramLoadPackBytes = 128;
static constexpr int MXdlPack = 2;
static constexpr int NXdlPack = 2;
static constexpr int KXdlPack = 2;
template <typename Problem>
static inline constexpr auto wg_attr_num_access = WGAttrNumAccessEnum::Single;
// std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
// ? WGAttrNumAccessEnum::Single
// : WGAttrNumAccessEnum::Double;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher< //
ADataType,
BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access<Problem>>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
ADataType,
BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
}
template <typename Problem, typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
const auto& naive_desc = naive_view.get_tensor_descriptor();
constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
static_assert(ndims == 2, "only support 2D tensor");
const auto rows = naive_desc.get_length(number<0>{});
const auto cols = naive_desc.get_length(number<1>{});
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
const index_t K0 = cols / (K1 * K2);
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
const index_t M0 = rows / M1;
const auto row_lens = make_tuple(M0, number<M1>{});
const auto desc_0 =
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),
make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
make_pass_through_transform(K0),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
const auto desc = transform_tensor_descriptor( //
desc_1,
make_tuple(make_merge_transform_v3_division_mod(row_lens),
make_merge_transform_v3_division_mod(col_lens)),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
return tensor_view<typename TensorView::buffer_view,
remove_cvref_t<decltype(desc)>,
TensorView::DstInMemOp>{naive_view.buf_, desc};
}
template <typename Problem, typename TensorView>
CK_TILE_DEVICE static constexpr auto
Make_F8AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
constexpr int DynamicTileOffsetFlag = 0;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
// implement swizzle pattern on global side
// because we can't adjust the ds_write pattern of BUFFER_LOAD_LDS.
auto swizzle_a_dram_view_1 = transform_tensor_view(
naive_view,
make_tuple(
// M-dim is not affected by swizzle pattern
make_unmerge_transform(
make_tuple(number<DynamicTileOffsetFlag>{}, number<MPerBlock>{})),
// K-dim is the swizzle dimension
make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
number<KPerBlock / KPack>{},
number<KPack>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}));
auto swizzle_a_dram_view_2 = transform_tensor_view(
swizzle_a_dram_view_1,
make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
make_xor_transform(make_tuple(number<MPerBlock>{},
number<ContiguousThreadsCntInDS_READ_16B>{})),
make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
return transform_tensor_view(
swizzle_a_dram_view_2,
make_tuple(
make_merge_transform_v3_division_mod(
make_tuple(number<DynamicTileOffsetFlag>{}, number<MPerBlock>{})),
make_merge_transform_v3_division_mod(make_tuple(number<DynamicTileOffsetFlag>{},
number<KPerBlock / KPack>{},
number<KPack>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = MPerBlock == 16
? GetSmemPackA<Problem>() * APackedSize / 4
: GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t M2 = get_warp_size() / K1; // 8
constexpr index_t M1 = BlockSize / get_warp_size(); // 4
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>, // M0,K0,K2
sequence<0, 0, 2>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
/*reduce transform layers,compare with old ck*/
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
constexpr index_t M2 = get_warp_size() / K1 / M3; // 2
constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
constexpr index_t Pad = 4 * K2; // 4 * 16
// constexpr index_t Pad = 0; // 4 * 16
// TODO: fix lds_a swizzle
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<M0>{},
number<M1>{},
number<K0>{},
number<M2>{},
number<M3>{},
number<K1>{},
number<K2>{}),
make_tuple(number<M1*(K0 * (M2 * M3 * K1 * K2) + (K0 - 1) * Pad)>{},
number<K0*(M2 * M3 * K1 * K2) + (K0 - 1) * Pad>{},
number<M2 * M3 * K1 * K2 + Pad>{},
number<M3 * K1 * K2>{},
number<K1 * K2>{},
number<K2>{},
number<1>{}),
number<K2>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc_permuted;
return a_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF8_ReadALdsBlockDescriptor()
{
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
/*reduce transform layers,compare with old ck*/
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
number<ContiguousThreadsCntInDS_READ_16B>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF8_WriteALdsBlockDescriptor()
{
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
return make_naive_tensor_descriptor(make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
make_tuple(number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16
constexpr int K_Lane = 64 / M_Lane; // 4
constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32
// constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
constexpr index_t num_access_v = 2;
constexpr int K1 = K_Thread / num_access_v; // 16
return make_static_tile_distribution(
std::conditional_t<
num_access_v == 1,
tile_distribution_encoding<
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>,
tile_distribution_encoding< //
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<num_access_v, K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 2>,
sequence<0, 2>>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t K1 = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t K0 = KWavePerBlk;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
constexpr index_t kKPerThread = 32;
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
constexpr index_t K2 = kKPerThread / num_access_v;
return make_static_tile_distribution(
std::conditional_t< //
num_access_v == 1,
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<K0, K1, K2>>, // 1 64 32
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 0>, sequence<1>>,
sequence<2>,
sequence<2>>,
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<num_access_v, K0, K1, K2>>, // 2 1 64 16
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 1>, sequence<2>>,
sequence<2, 2>,
sequence<0, 3>>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0);
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
constexpr index_t K_Lanes = 64 / M_Lanes;
// Y dimension (M) decomposition
constexpr index_t Y2 = M_Lanes;
constexpr index_t Y1 = M_Warps;
constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
// X dimension (K) decomposition
constexpr index_t X0 = K_Lanes;
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N_Warps>, // repeat N_warps
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<0, 2>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1);
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
constexpr index_t K_Lanes = 64 / N_Lanes;
// Y dimension (M) decomposition
constexpr index_t Y2 = N_Lanes;
constexpr index_t Y1 = N_Warps;
constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2);
// X dimension (K) decomposition
constexpr index_t X0 = K_Lanes;
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
return make_static_tile_distribution(
tile_distribution_encoding<sequence<M_Warps>, // ?
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{});
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
constexpr index_t M_Lane = TileShape::WarpTile::at(I0);
constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{});
constexpr index_t MWavePerBlk = M_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N_Wrap>, // ?
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{});
constexpr index_t NWavePerBlk = N_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<M_Wrap>, // ?
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
return sizeof(ADataType) *
MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return GetSmemSizeA<Problem>();
}
};
} // namespace ck_tile