[CK_Tile] Flatmm MX Cleanup & Explicit Offset Calculation (#3286)

Yi DING
2025-12-02 14:21:12 +08:00
committed by GitHub
parent 46f1d740f0
commit f211156ce6
5 changed files with 206 additions and 294 deletions

View File

@@ -158,7 +158,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert(
- "mx_prec", "fp4xfp4", "data type for activation and weight, support: fp6xfp6, fp8xfp8")
+ "mx_prec", "fp4xfp4", "data type for activation and weight, support: fp4xfp4, fp8xfp8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -75,7 +75,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
HasHotLoop,
TailNum>;
- using MXFlatmmPipeline = ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
+ using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,

View File

@@ -293,6 +293,15 @@ struct tile_window_with_static_distribution
0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
}
+ template <typename offset_t>
+ CK_TILE_DEVICE constexpr auto get_load_offset(offset_t = {}) const
+ {
+ constexpr auto bottom_tensor_idx_off = to_multi_index(offset_t{});
+ const auto bottom_tensor_coord_off = make_tensor_coordinate(
+ this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
+ return amd_wave_read_first_lane(bottom_tensor_coord_off.get_offset());
+ }
template <typename DataType,
typename StaticTileDistribution,
index_t i_access_unsupport_ = -1,
@@ -316,12 +325,7 @@ struct tile_window_with_static_distribution
else if constexpr(is_constant_v<offset_t>)
return offset_t::value;
else
- {
- auto bottom_tensor_idx_off = to_multi_index(offset_t{});
- auto bottom_tensor_coord_off = make_tensor_coordinate(
- this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
- return bottom_tensor_coord_off.get_offset();
- }
+ return get_load_offset(offset_t{});
}();
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
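
The new get_load_offset helper folds a compile-time window step into one scalar element offset, made wave-uniform with amd_wave_read_first_lane, so a hot loop can keep a single tile window and index it with precomputed integers instead of creating and moving per-iteration window objects. A minimal standalone sketch of the pattern, assuming a plain row-major layout (illustrative C++ model, not the ck_tile API):

    // Minimal model: a window over a row-major tensor turns a compile-time
    // (row, col) step into a linear offset once; loads then reuse the integer.
    #include <cstdio>

    template <long RowStride>
    struct window_model
    {
        const float* base;

        // Analogue of get_load_offset(tuple<number<Row>, number<Col>>{}).
        template <long Row, long Col>
        static constexpr long get_load_offset()
        {
            return Row * RowStride + Col;
        }

        // Analogue of load_tile_with_offset(window, offset).
        float load_with_offset(long off) const { return base[off]; }
    };

    int main()
    {
        float data[4 * 8];
        for(int i = 0; i < 4 * 8; ++i)
            data[i] = static_cast<float>(i);

        window_model<8> w{data};
        constexpr long step_k = window_model<8>::get_load_offset<0, 2>();
        constexpr long step_n = window_model<8>::get_load_offset<1, 0>();

        // Hot loop: "moving the window" is now integer arithmetic.
        for(long n = 0; n < 2; ++n)
            for(long k = 0; k < 3; ++k)
                std::printf("%.0f ", w.load_with_offset(n * step_n + k * step_k));
        std::printf("\n");
    }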

View File

@@ -46,8 +46,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
};
- template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
- struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
+ template <typename Problem, typename PipelinePolicy = MXFlatmmPipelineAgBgCrPolicy>
+ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
{
using Underlying = FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>;
@@ -470,17 +470,39 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
return PipelinePolicy::template MakeADramTileDistribution<Problem>();
}
+ template <typename... Args>
+ CK_TILE_DEVICE auto operator()(Args&&... args) const
+ {
+ auto c_warp_tensors = Run_(std::forward<Args>(args)...);
+ // Block GEMM Acc register tile
+ using CWarpDstr = typename WG::CWarpDstr;
+ constexpr auto c_warp_y_lengths =
+ to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
+ constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
+ auto c_block_tile = BlockFlatmm{}.MakeCBlockTile();
+ static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+ static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+ c_block_tile.set_y_sliced_thread_data(
+ merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+ merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+ c_warp_tensors(mIter)(nIter).get_thread_buffer());
+ });
+ });
+ return c_block_tile;
+ }
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp>
- CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
- const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
- const ScaleADramBlockWindowTmp& scale_a_window,
- const ScaleBDramBlockWindowTmp& scale_b_window,
- index_t num_loop,
- void* __restrict__ p_smem_ping,
- void* __restrict__ p_smem_pong) const
+ CK_TILE_DEVICE auto Run_(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
+ const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+ const ScaleADramBlockWindowTmp& scale_a_window,
+ const ScaleBDramBlockWindowTmp& scale_b_window,
+ index_t num_loop,
+ void* __restrict__ p_smem_ping,
+ void* __restrict__ p_smem_pong) const
{
#ifndef __gfx950__
static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now.");
@@ -497,19 +519,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
static_assert(NWarp == 4);
- using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
- constexpr auto c_warp_y_lengths =
- to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
- constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto a_dram_window =
- make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor<Problem>(
+ make_tile_window(PipelinePolicy::template MakeMX_AAsyncLoadDramDescriptor<Problem>(
a_copy_dram_window_tmp.get_bottom_tensor_view()),
a_copy_dram_window_tmp.get_window_lengths(),
a_copy_dram_window_tmp.get_window_origin(),
- PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
+ PipelinePolicy::template MakeMX_ADramTileDistribution<Problem>());
__builtin_amdgcn_sched_barrier(0);
@@ -518,7 +535,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
constexpr auto a_lds_block_desc =
- PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor<Problem>();
+ PipelinePolicy::template MakeMX_ALdsBlockDescriptor<Problem>();
auto a_lds_block_ping =
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
@@ -535,39 +552,34 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{0, 0},
- PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
+ PipelinePolicy::template MakeMX_ALDS_TileDistribution<Problem>());
auto a_warp_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{0, 0},
- PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
- // Block GEMM
- auto block_flatmm = BlockFlatmm();
- // Acc register tile
- auto c_block_tile = block_flatmm.MakeCBlockTile();
+ PipelinePolicy::template MakeMX_ALDS_TileDistribution<Problem>());
// B flat DRAM window for load
// pingpong buffer for B
- auto b_flat_dram_windows = generate_tuple(
+ auto b_flat_dram_window =
+ make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
+ make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
+ b_flat_dram_block_window_tmp.get_window_origin(),
+ PipelinePolicy::template MakeMX_BFlatDramTileDistribution<Problem>());
+ auto b_flat_dram_offsets = generate_tuple(
[&](auto nIter) {
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
constexpr auto packed_n_rank = nIter % number<NXdlPack>{};
- auto window_i = make_tile_window(
- b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
- make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
- b_flat_dram_block_window_tmp.get_window_origin(),
- PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>());
- move_tile_window(
- window_i,
- {number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank>{},
- number<0>{}});
- return window_i;
+ return b_flat_dram_window.get_load_offset(
+ tuple<number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter>,
+ number<0>>{}) +
+ b_flat_dram_window.get_load_offset(
+ tuple<number<packed_n_rank>, number<0>>{});
},
number<NIterPerWarp>{});
statically_indexed_array<
- statically_indexed_array<decltype(load_tile(b_flat_dram_windows(I0))), KIterPerWarp>,
+ statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_ping, b_warp_tensor_pong;
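
b_flat_dram_offsets replaces the old tuple of pre-moved B windows with one scalar base offset per n-iteration. Because the offset map is linear here, the packed-index and packed-rank components of nIter can be converted separately and summed, which is why the lambda above adds two get_load_offset results. A compile-time check of that identity in plain C++, with an assumed stride (illustrative, not the real descriptor):

    // nIter decomposes as packed_n_idx * NXdlPack + packed_n_rank; for a
    // linear (row-major, assumed) map, offsetting the two components
    // separately and summing equals offsetting the combined row index.
    constexpr long row_stride = 128; // assumed flat-B leading stride

    constexpr long offset_of(long row, long col) { return row * row_stride + col; }

    constexpr long NXdlPack = 2, NFlatPerBlockPerIter = 8; // assumed sizes
    constexpr long packed_n_idx = 1, packed_n_rank = 1;    // i.e. nIter == 3

    static_assert(offset_of(packed_n_idx * NXdlPack * NFlatPerBlockPerIter, 0) +
                      offset_of(packed_n_rank, 0) ==
                  offset_of(packed_n_idx * NXdlPack * NFlatPerBlockPerIter +
                                packed_n_rank, 0));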
@@ -576,41 +588,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
scale_a_window.get_bottom_tensor_view(),
make_tuple(number<MWarp * WG::kM>{}, number<64 / WG::kM>{}),
scale_a_window.get_window_origin(),
- PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution<Problem>());
+ PipelinePolicy::template MakeMX_ScaleA_FlatDramTileDistribution<Problem>());
+ const auto scale_a_dram_step_m = amd_wave_read_first_lane(
+ scale_a_dram_window.get_load_offset(tuple<number<MWarp * WG::kM>, number<0>>{}));
+ const auto scale_a_dram_step_k = amd_wave_read_first_lane(
+ scale_a_dram_window.get_load_offset(tuple<number<0>, number<64 / WG::kM>>{}));
auto scale_b_dram_window = make_tile_window(
scale_b_window.get_bottom_tensor_view(),
make_tuple(number<NWarp * WG::kN>{}, number<64 / WG::kN>{}),
scale_b_window.get_window_origin(),
- PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution<Problem>());
+ PipelinePolicy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
+ const auto scale_b_dram_step_n = amd_wave_read_first_lane(
+ scale_b_dram_window.get_load_offset(tuple<number<NWarp * WG::kN>, number<0>>{}));
+ const auto scale_b_dram_step_k = amd_wave_read_first_lane(
+ scale_b_dram_window.get_load_offset(tuple<number<0>, number<64 / WG::kN>>{}));
+ constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
+ constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
+ constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
// ping pong buffer for scale A
statically_indexed_array<
- statically_indexed_array<decltype(scale_a_dram_window), KIterPerWarp / KXdlPack>,
- MIterPerWarp / MXdlPack>
- scale_a_dram_windows;
- statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
- KIterPerWarp / KXdlPack>,
- MIterPerWarp / MXdlPack>
- scale_a_tile_tensor_ping;
- statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
- KIterPerWarp / KXdlPack>,
- MIterPerWarp / MXdlPack>
- scale_a_tile_tensor_pong;
+ statically_indexed_array<decltype(load_tile(scale_a_dram_window)), KPackIterPerWarp>,
+ MPackIterPerWarp>
+ scale_a_tile_tensor_ping, scale_a_tile_tensor_pong;
// ping pong buffer for scale B
statically_indexed_array<
- statically_indexed_array<decltype(scale_b_dram_window), KIterPerWarp / KXdlPack>,
- NIterPerWarp / NXdlPack>
- scale_b_dram_windows;
- statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
- KIterPerWarp / KXdlPack>,
- NIterPerWarp / NXdlPack>
- scale_b_tile_tensor_ping;
- statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
- KIterPerWarp / KXdlPack>,
- NIterPerWarp / NXdlPack>
- scale_b_tile_tensor_pong;
+ statically_indexed_array<decltype(load_tile(scale_b_dram_window)), KPackIterPerWarp>,
+ NPackIterPerWarp>
+ scale_b_tile_tensor_ping, scale_b_tile_tensor_pong;
auto async_load_tile_ = [](auto lds, auto dram) {
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
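
The scale windows get the same treatment: instead of a statically indexed array of windows, each pre-moved with move_tile_window, one window is kept and two wave-uniform step offsets are precomputed; tile (mIter_pack, kIter_pack) is then loaded at mIter_pack * step_m + kIter_pack * step_k. A plain-C++ model of the stepping, with strides and sizes assumed purely for illustration:

    #include <cstdio>

    int main()
    {
        constexpr long row_stride = 64; // assumed scale-A leading stride
        auto load_offset = [](long r, long c) { return r * row_stride + c; };

        // Computed once; the kernel additionally wraps the real offsets in
        // amd_wave_read_first_lane so they stay in scalar registers.
        const long step_m = load_offset(16, 0); // stand-in for MWarp * WG::kM rows
        const long step_k = load_offset(0, 4);  // stand-in for 64 / WG::kM columns

        for(long m = 0; m < 2; ++m)     // MPackIterPerWarp
            for(long k = 0; k < 2; ++k) // KPackIterPerWarp
                std::printf("scale tile (%ld,%ld) -> offset %ld\n",
                            m, k, m * step_m + k * step_k);
    }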
@@ -625,35 +633,31 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
- b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
+ b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
});
// move B window to next flat K
- move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter});
+ b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
+ tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
});
// prefetch Scale A
- static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
- move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
- {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
+ static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
+ scale_a_dram_window,
- scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
- load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
+ mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
});
});
// move Scale A window to next K
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// prefetch Scale B
- static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
- move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
- {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
- scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
- load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
+ static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
+ scale_b_dram_window,
+ nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
});
});
// move Scale B window to next K
@@ -667,7 +671,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
move_tile_window(a_dram_window, {0, kKPerBlock});
}
// initialize C
- clear_tile(c_block_tile);
+ statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp>
+ c_warp_tensors;
+ static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+ static_for<0, NIterPerWarp, 1>{}(
+ [&](auto nIter) { clear_tile(c_warp_tensors(mIter)(nIter)); });
+ });
statically_indexed_array<decltype(load_tile(a_warp_window_pong)), m_preload> a_warp_tensor;
@@ -688,40 +697,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
- b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
+ b_flat_dram_window,
+ b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
- move_tile_window(b_flat_dram_windows(nIter),
- {0, BlockGemmShape::flatKPerBlock});
+ b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
+ tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
});
});
// prefetch Scale A and Scale B (2i+1)
- static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
- move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
- {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
- scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
- load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+ scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
+ scale_a_dram_window,
+ mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
});
});
- static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
- move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
- {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
- scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
- load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
+ scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
+ scale_b_dram_window,
+ nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
});
});
// GEMM 2i
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
- static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+ static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
@@ -729,39 +735,22 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
- // read C warp tensor from C block tensor
- CWarpTensor c_warp_tensor;
- c_warp_tensor.get_thread_buffer() =
- c_block_tile.get_y_sliced_thread_data(
- merge_sequences(sequence<m_iter, n_iter>{},
- c_warp_y_index_zeros),
- merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
- c_warp_tensor,
+ c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
- b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
- kIter_pack * number<KXdlPack>{} + ikxdl),
+ b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0],
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]);
- // write C warp tensor into C block tensor
- c_block_tile.set_y_sliced_thread_data(
- merge_sequences(sequence<m_iter, n_iter>{},
- c_warp_y_index_zeros),
- merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
- c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
constexpr auto addr =
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
- (nIter_pack == NIterPerWarp / NXdlPack - 1))
+ (nIter_pack == NPackIterPerWarp - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
@@ -802,81 +791,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
- b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
+ b_flat_dram_window,
+ b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
- move_tile_window(b_flat_dram_windows(nIter),
- {0, BlockGemmShape::flatKPerBlock});
+ b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
+ tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
});
});
// prefetch Scale A and Scale B (2i+2)
- static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
- move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
- {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
- scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
- load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+ scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
+ scale_a_dram_window,
+ mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
});
});
- static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
- move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
- {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
- scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
- load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
+ scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
+ scale_b_dram_window,
+ nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
});
});
// GEMM 2i+1
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
- static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+ static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
+ constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
+ constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
- // read C warp tensor from C block tensor
- CWarpTensor c_warp_tensor;
- c_warp_tensor.get_thread_buffer() =
- c_block_tile.get_y_sliced_thread_data(
- merge_sequences(
- sequence<mIter_pack * MXdlPack + imxdl,
- nIter_pack * NXdlPack + inxdl>{},
- c_warp_y_index_zeros),
- merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+ constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
- c_warp_tensor,
+ c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
- b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
- kIter_pack * number<KXdlPack>{} + ikxdl),
+ b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
- // write C warp tensor into C block tensor
- c_block_tile.set_y_sliced_thread_data(
- merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
- nIter_pack * NXdlPack + inxdl>{},
- c_warp_y_index_zeros),
- merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
- c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
- constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
- (kIter_pack * KXdlPack + ikxdl) * 2 +
- (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
- m_preload;
+ constexpr auto addr =
+ m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
- (nIter_pack == NIterPerWarp / NXdlPack - 1))
+ (nIter_pack == NPackIterPerWarp - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
@@ -928,78 +896,54 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
- b_flat_dram_windows(nIter),
- make_tuple(number<0>{}, number<kIter * KFlatPerBlockPerIter>{}));
+ b_flat_dram_window,
+ b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
});
});
// prefetch Scale A and Scale B (2i+1)
- static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
- move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
- {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
- scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
- load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
+ static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
+ scale_a_dram_window,
+ mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
});
});
- static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
- move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
- {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
- scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
- load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
+ static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
+ scale_b_dram_window,
+ nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
});
});
// GEMM loopK-1
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
- static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+ static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
+ constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
+ constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
- // read C warp tensor from C block tensor
- CWarpTensor c_warp_tensor;
- c_warp_tensor.get_thread_buffer() =
- c_block_tile.get_y_sliced_thread_data(
- merge_sequences(
- sequence<mIter_pack * MXdlPack + imxdl,
- nIter_pack * NXdlPack + inxdl>{},
- c_warp_y_index_zeros),
- merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+ constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
- c_warp_tensor,
+ c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
- b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
- kIter_pack * number<KXdlPack>{} + ikxdl),
+ b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
- // write C warp tensor into C block tensor
- c_block_tile.set_y_sliced_thread_data(
- merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
- nIter_pack * NXdlPack + inxdl>{},
- c_warp_y_index_zeros),
- merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
- c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
- constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
- (kIter_pack * KXdlPack + ikxdl) * 2 +
- (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
- m_preload;
+ constexpr auto addr =
+ m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
- (nIter_pack == NIterPerWarp / NXdlPack - 1))
+ (nIter_pack == NPackIterPerWarp - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
@@ -1028,50 +972,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
Last2ndHotLoopScheduler();
// GEMM loopK
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
- static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+ static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
+ constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
+ constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
- // read C warp tensor from C block tensor
- CWarpTensor c_warp_tensor;
- c_warp_tensor.get_thread_buffer() =
- c_block_tile.get_y_sliced_thread_data(
- merge_sequences(
- sequence<mIter_pack * MXdlPack + imxdl,
- nIter_pack * NXdlPack + inxdl>{},
- c_warp_y_index_zeros),
- merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+ constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
- c_warp_tensor,
+ c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
- b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
- kIter_pack * number<KXdlPack>{} + ikxdl),
+ b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
- // write C warp tensor into C block tensor
- c_block_tile.set_y_sliced_thread_data(
- merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
- nIter_pack * NXdlPack + inxdl>{},
- c_warp_y_index_zeros),
- merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
- c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
- constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
- (kIter_pack * KXdlPack + ikxdl) * 2 +
- (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
- m_preload;
+ constexpr auto addr =
+ m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
- (nIter_pack == NIterPerWarp / NXdlPack - 1))
+ (nIter_pack == NPackIterPerWarp - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
@@ -1089,50 +1015,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
else if constexpr(TailNum == TailNumber::Odd)
{
// GEMM loopK
- static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
- static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
- static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
+ static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
+ static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
+ static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
+ constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
+ constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
- // read C warp tensor from C block tensor
- CWarpTensor c_warp_tensor;
- c_warp_tensor.get_thread_buffer() =
- c_block_tile.get_y_sliced_thread_data(
- merge_sequences(
- sequence<mIter_pack * MXdlPack + imxdl,
- nIter_pack * NXdlPack + inxdl>{},
- c_warp_y_index_zeros),
- merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+ constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
- c_warp_tensor,
+ c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
- b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
- kIter_pack * number<KXdlPack>{} + ikxdl),
+ b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
- // write C warp tensor into C block tensor
- c_block_tile.set_y_sliced_thread_data(
- merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
- nIter_pack * NXdlPack + inxdl>{},
- c_warp_y_index_zeros),
- merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
- c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
- constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
- (kIter_pack * KXdlPack + ikxdl) * 2 +
- (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
- m_preload;
+ constexpr auto addr =
+ m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
- (nIter_pack == NIterPerWarp / NXdlPack - 1))
+ (nIter_pack == NPackIterPerWarp - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
@@ -1151,7 +1059,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{
static_assert(false, "Wrong TailNum");
}
- return c_block_tile;
+ return c_warp_tensors;
}
};
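
Net effect of the pipeline restructure: Run_ accumulates into an MIterPerWarp x NIterPerWarp array of warp tensors and returns it, and the new operator() packs that array into the C block tile once at the end, so the hot loop no longer round-trips every MFMA result through get_y_sliced_thread_data / set_y_sliced_thread_data. A reduced sketch of that control flow in plain C++ (shapes and types invented for illustration):

    #include <array>
    #include <cstdio>

    constexpr int M = 2, N = 2, K = 3;

    using warp_acc  = float;                    // stand-in for one warp tile
    using block_acc = std::array<float, M * N>; // stand-in for the block tile

    // Analogue of Run_: the hot loop touches only the small accumulators.
    std::array<std::array<warp_acc, N>, M> run()
    {
        std::array<std::array<warp_acc, N>, M> acc{};
        for(int k = 0; k < K; ++k)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    acc[m][n] += 1.0f; // stand-in for the scaled warp GEMM
        return acc;
    }

    // Analogue of the new operator(): one assembly pass after the loop.
    block_acc assemble()
    {
        auto acc = run();
        block_acc c{};
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] = acc[m][n];
        return c;
    }

    int main()
    {
        for(float v : assemble())
            std::printf("%.0f ", v); // each entry saw K accumulations
        std::printf("\n");
    }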

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
- struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
+ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
@@ -58,7 +58,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem, typename TensorView>
CK_TILE_DEVICE static constexpr auto
- MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
+ MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
@@ -107,7 +107,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
- CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
+ CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
@@ -140,7 +140,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
- CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
+ CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
@@ -218,7 +218,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
- CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
+ CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
@@ -255,7 +255,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
- CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution()
+ CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
@@ -298,7 +298,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
- CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
+ CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -335,7 +335,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
- CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution()
+ CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -372,7 +372,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
- CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
+ CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
@@ -394,7 +394,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
- CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
+ CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
@@ -420,8 +420,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
- return sizeof(ADataType) *
- MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
+ return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor<Problem>().get_element_space_size() /
+ APackedSize;
}
template <typename Problem>
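
For reference, the reflowed GetSmemSize return computes LDS bytes as sizeof(ADataType) * element count / PackedSize; with fp4 data two elements share one byte, so PackedSize == 2 halves the byte count relative to the element count. A worked instance with assumed numbers (illustrative only):

    #include <cstddef>

    constexpr std::size_t element_space_size = 128 * 64; // assumed LDS elements
    constexpr std::size_t packed_size        = 2;        // fp4: 2 elements/byte
    constexpr std::size_t lds_bytes =
        sizeof(unsigned char) * element_space_size / packed_size;

    static_assert(lds_bytes == 4096);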