mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] Add Flatmm MX FP8 (#3208)
* Use async for flatmm mxfp4 * Fix preshuffle * Add flatmm mxfp8 * Thanks, Copilot * Thanks Copilot again~
This commit is contained in:
@@ -19,6 +19,7 @@ template <> struct typeToStr<fp8_t> { static constexpr const char * name = "fp8"
|
||||
// Name lookup for bf8_t: typeToStr<bf8_t>::name is the literal "bf8".
template <>
struct typeToStr<bf8_t>
{
    static constexpr const char* name = "bf8";
};
|
||||
// Name lookup for int8_t: typeToStr<int8_t>::name is the literal "int8".
template <>
struct typeToStr<int8_t>
{
    static constexpr const char* name = "int8";
};
|
||||
// Name lookup for pk_int4_t: typeToStr<pk_int4_t>::name is the literal "pk_int4".
template <>
struct typeToStr<pk_int4_t>
{
    static constexpr const char* name = "pk_int4";
};
|
||||
// Name lookup for pk_fp4_t: typeToStr<pk_fp4_t>::name is the literal "pk_fp4".
template <>
struct typeToStr<pk_fp4_t>
{
    static constexpr const char* name = "pk_fp4";
};
|
||||
|
||||
// Primary template: declared here without a definition, keyed on a
// memory_operation_enum value. Explicit specializations supply `name`.
template <memory_operation_enum MemOp>
struct memOpToStr;
|
||||
// Name lookup for memory_operation_enum::set: the specialization's `name`
// is the literal "set".
template <>
struct memOpToStr<memory_operation_enum::set>
{
    static constexpr const char* name = "set";
};
|
||||
|
||||
@@ -143,16 +143,24 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
|
||||
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
|
||||
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
|
||||
const index_t kFlatN = kargs.N / kNWarpTile;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
|
||||
"wrong! vector size for B tensor");
|
||||
auto&& naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
|
||||
auto&& desc = transform_tensor_descriptor(
|
||||
naive_desc,
|
||||
make_tuple(make_pass_through_transform(kFlatN),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
|
||||
@@ -44,7 +44,10 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
else if(TailNumber::Odd == tail_num)
|
||||
return TailHandler<DispatchHotloop, TailNumber::Odd>(run_func, has_hot_loop);
|
||||
else
|
||||
{
|
||||
assert(("Wrong TailNumber!", false));
|
||||
return decltype(TailHandler<>(run_func, true, TailNumber::Even)){};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
static constexpr int NXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int KXdlPack = 2;
|
||||
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
|
||||
static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread;
|
||||
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
|
||||
@@ -122,9 +122,10 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static constexpr index_t MXdlPack = Problem::MXdlPack;
|
||||
static constexpr index_t NXdlPack = Problem::NXdlPack;
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
static constexpr index_t MXdlPack = Problem::MXdlPack;
|
||||
static constexpr index_t NXdlPack = Problem::NXdlPack;
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
|
||||
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
|
||||
@@ -138,25 +139,25 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
static constexpr index_t mfma_per_wg = 1; // 950 only
|
||||
|
||||
static constexpr index_t dsread_per_wg =
|
||||
WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize / Problem::VectorLoadSize;
|
||||
static_assert((WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize) %
|
||||
Problem::VectorLoadSize ==
|
||||
0);
|
||||
static constexpr index_t dsread_per_wg = WG::kM * WG::kK / AK1 / WaveSize;
|
||||
static_assert((WG::kM * WG::kK) % (AK1 * WaveSize) == 0);
|
||||
|
||||
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / NWarp;
|
||||
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
|
||||
static constexpr index_t ScaleBload_K1 = NXdlPack * KXdlPack; // fixed for fp4
|
||||
static constexpr index_t Bload_num = Bload_num_perK * KIterPerWarp;
|
||||
static constexpr index_t ScaleBload_num =
|
||||
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 / WaveSize;
|
||||
static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
kNPerBlock * kKPerBlock / NWarp / ScaleGranularityK / NXdlPack / KXdlPack / WaveSize;
|
||||
static constexpr index_t ScaleAload_num =
|
||||
kMPerBlock * kKPerBlock / MWarp / ScaleGranularityK / MXdlPack / KXdlPack / WaveSize;
|
||||
|
||||
// static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
|
||||
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
@@ -219,7 +220,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{
|
||||
if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
|
||||
{
|
||||
@@ -234,7 +235,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{
|
||||
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
|
||||
{
|
||||
@@ -470,18 +471,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_ping,
|
||||
void* __restrict__ p_smem_pong) const
|
||||
{
|
||||
#ifndef __gfx950__
|
||||
static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now.");
|
||||
@@ -495,9 +494,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
// const index_t iNWarp = get_warp_id() % NWarp;
|
||||
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
|
||||
static_assert(NWarp == 4);
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
@@ -506,6 +504,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto a_dram_window =
|
||||
make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor<Problem>(
|
||||
a_copy_dram_window_tmp.get_bottom_tensor_view()),
|
||||
a_copy_dram_window_tmp.get_window_lengths(),
|
||||
a_copy_dram_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// A tile in LDS
|
||||
@@ -520,93 +525,51 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
auto a_lds_block_pong =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
auto a_copy_lds_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
auto a_copy_lds_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
auto a_store_lds_window_ping = make_tile_window(
|
||||
a_lds_block_ping, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
auto a_store_lds_window_pong = make_tile_window(
|
||||
a_lds_block_pong, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp =
|
||||
auto a_warp_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
|
||||
auto a_warp_window_pong_tmp =
|
||||
auto a_warp_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_ping;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
auto packed_m_idx = mIter / number<MXdlPack>{};
|
||||
auto packed_m_rank = mIter % number<MXdlPack>{};
|
||||
|
||||
move_tile_window(
|
||||
a_warp_windows_ping(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
|
||||
kIter * KPerBlockPerIter});
|
||||
move_tile_window(
|
||||
a_warp_windows_pong(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
|
||||
kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
auto block_flatmm = BlockFlatmm();
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_flatmm.MakeCBlockTile();
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>();
|
||||
|
||||
auto b_flat_dram_window = make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
using MXFP4_B_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// Use v4i32 as the data type between basic blocks to avoid unpack and repack operations.
|
||||
using V4UInt_B_Buffer = thread_buffer<uint32_t, 4>;
|
||||
union UnionBuf
|
||||
{
|
||||
V4UInt_B_Buffer u = 0;
|
||||
MXFP4_B_Buffer mxfp4;
|
||||
} ub;
|
||||
|
||||
// pingpong buffer for B
|
||||
auto b_flat_dram_windows = generate_tuple(
|
||||
[&](auto nIter) {
|
||||
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
constexpr auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
auto window_i = make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>());
|
||||
move_tile_window(
|
||||
window_i,
|
||||
{number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank>{},
|
||||
number<0>{}});
|
||||
return window_i;
|
||||
},
|
||||
number<NIterPerWarp>{});
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_windows(I0))), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
b_warp_tensor_ping, b_warp_tensor_pong;
|
||||
|
||||
// pingpong buffer for Scale A and Scale B
|
||||
auto scale_a_dram_window = make_tile_window(
|
||||
@@ -649,29 +612,24 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_tile_tensor_pong;
|
||||
|
||||
auto async_load_tile_ = [](auto lds, auto dram) {
|
||||
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
|
||||
};
|
||||
|
||||
// HEAD
|
||||
// Prefetch A0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
async_load_tile_(a_store_lds_window_ping, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock});
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter});
|
||||
|
||||
// prefetch Scale A
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
@@ -700,71 +658,40 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
// move Scale B window to next K
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// A_Lds_TileDist may differ from ADramTileDistribution
|
||||
auto a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch A1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
if constexpr(HasHotLoop || TailNum == TailNumber::Even)
|
||||
{
|
||||
async_load_tile_(a_store_lds_window_pong, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock});
|
||||
}
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
clear_tile(c_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
using MXFP4_A_Buffer_ping =
|
||||
decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{})));
|
||||
// Use v4i32 as the data type between basic blocks to avoid unpack and repack operations.
|
||||
using V4UInt_A_Buffer = thread_buffer<uint32_t, 4>;
|
||||
union UnionBuf_A_ping
|
||||
{
|
||||
V4UInt_A_Buffer u = 0;
|
||||
MXFP4_A_Buffer_ping mxfp4;
|
||||
} ua_ping;
|
||||
|
||||
using MXFP4_A_Buffer_pong =
|
||||
decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{})));
|
||||
union UnionBuf_A_pong
|
||||
{
|
||||
V4UInt_A_Buffer u = 0;
|
||||
MXFP4_A_Buffer_pong mxfp4;
|
||||
} ua_pong;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_window_pong)), m_preload> a_warp_tensor;
|
||||
|
||||
// preload A00,A10... from lds
|
||||
statically_indexed_array<V4UInt_A_Buffer, m_preload> a_warp_tensor;
|
||||
|
||||
s_waitcnt_barrier</*vmcnt*/ dswrite_num_perK>();
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
|
||||
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_ping.u;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_ping, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
auto main_body_implx2 = [&]() mutable {
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
move_tile_window(b_flat_dram_windows(nIter),
|
||||
{0, BlockGemmShape::flatKPerBlock});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -791,15 +718,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
@@ -807,30 +725,26 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<m_iter, n_iter>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_ping ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0],
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -838,68 +752,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
merge_sequences(sequence<m_iter, n_iter>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
|
||||
s_waitcnt< // vmcnt
|
||||
Bload_num + ScaleAload_num + ScaleBload_num>();
|
||||
block_sync_lds();
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
async_load_tile_(a_store_lds_window_ping, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// preload A(2i+1)
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_pong, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
// Next K
|
||||
////////////////////////////// Next K //////////////////////////////
|
||||
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
move_tile_window(b_flat_dram_windows(nIter),
|
||||
{0, BlockGemmShape::flatKPerBlock});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -926,15 +832,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
@@ -953,20 +850,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_pong ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
@@ -988,39 +878,47 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_pong,
|
||||
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
// barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished
|
||||
s_waitcnt< // vmcnt
|
||||
Bload_num + ScaleAload_num + ScaleBload_num>();
|
||||
block_sync_lds();
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
async_load_tile_(a_store_lds_window_pong, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// preload A(2i+2)
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_ping, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
|
||||
});
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
iCounter--;
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
do
|
||||
{
|
||||
main_body_implx2();
|
||||
iCounter--;
|
||||
} while(iCounter > 0);
|
||||
}
|
||||
|
||||
// TAIL
|
||||
@@ -1029,18 +927,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter),
|
||||
make_tuple(number<0>{}, number<kIter * KFlatPerBlockPerIter>{}));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1055,7 +944,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
@@ -1067,10 +955,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
@@ -1089,20 +973,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_ping ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -1124,30 +1001,28 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
|
||||
s_waitcnt< // vmcnt
|
||||
Bload_num + ScaleAload_num + ScaleBload_num>();
|
||||
block_sync_lds();
|
||||
|
||||
// preload A(2i+1)
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_pong, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
|
||||
});
|
||||
|
||||
Last2ndHotLoopScheduler();
|
||||
@@ -1170,19 +1045,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_pong ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
@@ -1204,18 +1073,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_pong,
|
||||
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1244,20 +1106,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_ping ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -1279,18 +1134,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1299,32 +1147,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
static_assert(false, "Wrong TailNum");
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_flat_window_tmp,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_flat_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
scale_a_flat_window_tmp,
|
||||
scale_b_flat_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -13,22 +13,139 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr index_t KBPerLoad = 32;
|
||||
static constexpr index_t kDramLoadPackBytes = 128;
|
||||
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
static constexpr int KXdlPack = 2;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
static inline constexpr auto wg_attr_num_access =
|
||||
std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
|
||||
? WGAttrNumAccessEnum::Single
|
||||
: WGAttrNumAccessEnum::Double;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
static_assert(
|
||||
sizeof(ADataType) * numeric_traits<BDataType>::PackedSize ==
|
||||
sizeof(BDataType) * numeric_traits<ADataType>::PackedSize,
|
||||
"sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!");
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher< //
|
||||
ADataType,
|
||||
BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access<Problem>>;
|
||||
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
|
||||
ADataType,
|
||||
BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem, typename TensorView>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
const auto& naive_desc = naive_view.get_tensor_descriptor();
|
||||
constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
|
||||
static_assert(ndims == 2, "only support 2D tensor");
|
||||
const auto rows = naive_desc.get_length(number<0>{});
|
||||
const auto cols = naive_desc.get_length(number<1>{});
|
||||
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
const index_t K0 = cols / (K1 * K2);
|
||||
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
|
||||
|
||||
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
|
||||
const index_t M0 = rows / M1;
|
||||
const auto row_lens = make_tuple(M0, number<M1>{});
|
||||
|
||||
const auto desc_0 =
|
||||
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
|
||||
const auto desc_1 = transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_pass_through_transform(M0),
|
||||
make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
|
||||
make_pass_through_transform(K0),
|
||||
make_pass_through_transform(number<K2>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
|
||||
const auto desc = transform_tensor_descriptor( //
|
||||
desc_1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(row_lens),
|
||||
make_merge_transform_v3_division_mod(col_lens)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
|
||||
|
||||
return tensor_view<typename TensorView::buffer_view,
|
||||
remove_cvref_t<decltype(desc)>,
|
||||
TensorView::DstInMemOp>{naive_view.buf_, desc};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
|
||||
{
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
|
||||
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
constexpr index_t M2 = get_warp_size() / K1; // 8
|
||||
constexpr index_t M1 = BlockSize / get_warp_size(); // 4
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
|
||||
tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>, // M0,K0,K2
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
@@ -36,65 +153,70 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
|
||||
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
|
||||
constexpr index_t M2 = get_warp_size() / K1 / M3; // 2
|
||||
constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
|
||||
constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
|
||||
static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
|
||||
|
||||
constexpr index_t Pad = 4 * K2; // 4 * 32
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
make_tuple(number<M0>{},
|
||||
number<M1>{},
|
||||
number<K0>{},
|
||||
number<M2>{},
|
||||
number<M3>{},
|
||||
number<K1>{},
|
||||
number<K2>{}),
|
||||
make_tuple(number<M1*(K0 * (M2 * M3 * K1 * K2) + (K0 - 1) * Pad)>{},
|
||||
number<K0*(M2 * M3 * K1 * K2) + (K0 - 1) * Pad>{},
|
||||
number<M2 * M3 * K1 * K2 + Pad>{},
|
||||
number<M3 * K1 * K2>{},
|
||||
number<K1 * K2>{},
|
||||
number<K2>{},
|
||||
number<1>{}),
|
||||
number<K2>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
make_tuple(make_pass_through_transform(M0),
|
||||
make_pass_through_transform(M1),
|
||||
make_pass_through_transform(K0),
|
||||
make_pass_through_transform(M2),
|
||||
make_xor_transform(make_tuple(number<M3>{}, number<K1>{})),
|
||||
make_pass_through_transform(number<K2>{})),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{}));
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
a_lds_block_desc_1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
|
||||
make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// return a_lds_block_desc_permuted;
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
|
||||
{
|
||||
@@ -105,20 +227,31 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
|
||||
constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M_Lane = TileShape::WarpTile::at(I0);
|
||||
constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16
|
||||
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I0); // 4
|
||||
constexpr int K_Lane = 64 / M_Lane; // 4
|
||||
|
||||
constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 32
|
||||
constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32
|
||||
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
|
||||
constexpr int K1 = K_Thread / num_access_v; // 16
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<N_warps>,
|
||||
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
std::conditional_t<
|
||||
num_access_v == 1,
|
||||
tile_distribution_encoding<
|
||||
sequence<N_warps>,
|
||||
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
sequence<2>,
|
||||
sequence<1>>,
|
||||
tile_distribution_encoding< //
|
||||
sequence<N_warps>,
|
||||
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<num_access_v, K_Lane, K1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -132,25 +265,36 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t K1 = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t K0 = KWavePerBlk;
|
||||
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
constexpr index_t kKPerThread = 32;
|
||||
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
|
||||
constexpr index_t K2 = kKPerThread / num_access_v;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>,
|
||||
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
std::conditional_t< //
|
||||
num_access_v == 1,
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
|
||||
sequence<K0, K1, K2>>, // 1 64 32
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>,
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
|
||||
sequence<num_access_v, K0, K1, K2>>, // 2 1 64 16
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 1>, sequence<2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 3>>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -270,6 +414,21 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
return sizeof(ADataType) *
|
||||
MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return GetSmemSizeA<Problem>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user