mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_Tile] Flatmm MX Cleanup & Explicite Offset Calculation (#3286)
This commit is contained in:
@@ -293,6 +293,15 @@ struct tile_window_with_static_distribution
|
||||
0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename offset_t>
|
||||
CK_TILE_DEVICE constexpr auto get_load_offset(offset_t = {}) const
|
||||
{
|
||||
constexpr auto bottom_tensor_idx_off = to_multi_index(offset_t{});
|
||||
const auto bottom_tensor_coord_off = make_tensor_coordinate(
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
|
||||
return amd_wave_read_first_lane(bottom_tensor_coord_off.get_offset());
|
||||
}
|
||||
|
||||
template <typename DataType,
|
||||
typename StaticTileDistribution,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
@@ -316,12 +325,7 @@ struct tile_window_with_static_distribution
|
||||
else if constexpr(is_constant_v<offset_t>)
|
||||
return offset_t::value;
|
||||
else
|
||||
{
|
||||
auto bottom_tensor_idx_off = to_multi_index(offset_t{});
|
||||
auto bottom_tensor_coord_off = make_tensor_coordinate(
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
|
||||
return bottom_tensor_coord_off.get_offset();
|
||||
}
|
||||
return get_load_offset(offset_t{});
|
||||
}();
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
|
||||
@@ -46,8 +46,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
|
||||
struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
|
||||
template <typename Problem, typename PipelinePolicy = MXFlatmmPipelineAgBgCrPolicy>
|
||||
struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
|
||||
{
|
||||
using Underlying = FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>;
|
||||
|
||||
@@ -470,17 +470,39 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
return PipelinePolicy::template MakeADramTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
CK_TILE_DEVICE auto operator()(Args&&... args) const
|
||||
{
|
||||
auto c_warp_tensors = Run_(std::forward<Args>(args)...);
|
||||
|
||||
// Block GEMM Acc register tile
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
auto c_block_tile = BlockFlatmm{}.MakeCBlockTile();
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensors(mIter)(nIter).get_thread_buffer());
|
||||
});
|
||||
});
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_ping,
|
||||
void* __restrict__ p_smem_pong) const
|
||||
CK_TILE_DEVICE auto Run_(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_ping,
|
||||
void* __restrict__ p_smem_pong) const
|
||||
{
|
||||
#ifndef __gfx950__
|
||||
static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now.");
|
||||
@@ -497,19 +519,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
|
||||
static_assert(NWarp == 4);
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto a_dram_window =
|
||||
make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor<Problem>(
|
||||
make_tile_window(PipelinePolicy::template MakeMX_AAsyncLoadDramDescriptor<Problem>(
|
||||
a_copy_dram_window_tmp.get_bottom_tensor_view()),
|
||||
a_copy_dram_window_tmp.get_window_lengths(),
|
||||
a_copy_dram_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMX_ADramTileDistribution<Problem>());
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -518,7 +535,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor<Problem>();
|
||||
PipelinePolicy::template MakeMX_ALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block_ping =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
@@ -535,39 +552,34 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMX_ALDS_TileDistribution<Problem>());
|
||||
auto a_warp_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
|
||||
|
||||
// Block GEMM
|
||||
auto block_flatmm = BlockFlatmm();
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_flatmm.MakeCBlockTile();
|
||||
PipelinePolicy::template MakeMX_ALDS_TileDistribution<Problem>());
|
||||
|
||||
// B flat DRAM window for load
|
||||
|
||||
// pingpong buffer for B
|
||||
auto b_flat_dram_windows = generate_tuple(
|
||||
auto b_flat_dram_window =
|
||||
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMX_BFlatDramTileDistribution<Problem>());
|
||||
auto b_flat_dram_offsets = generate_tuple(
|
||||
[&](auto nIter) {
|
||||
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
constexpr auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
auto window_i = make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>());
|
||||
move_tile_window(
|
||||
window_i,
|
||||
{number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank>{},
|
||||
number<0>{}});
|
||||
return window_i;
|
||||
return b_flat_dram_window.get_load_offset(
|
||||
tuple<number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter>,
|
||||
number<0>>{}) +
|
||||
b_flat_dram_window.get_load_offset(
|
||||
tuple<number<packed_n_rank>, number<0>>{});
|
||||
},
|
||||
number<NIterPerWarp>{});
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_windows(I0))), KIterPerWarp>,
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping, b_warp_tensor_pong;
|
||||
|
||||
@@ -576,41 +588,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
scale_a_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<MWarp * WG::kM>{}, number<64 / WG::kM>{}),
|
||||
scale_a_window.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMX_ScaleA_FlatDramTileDistribution<Problem>());
|
||||
const auto scale_a_dram_step_m = amd_wave_read_first_lane(
|
||||
scale_a_dram_window.get_load_offset(tuple<number<MWarp * WG::kM>, number<0>>{}));
|
||||
const auto scale_a_dram_step_k = amd_wave_read_first_lane(
|
||||
scale_a_dram_window.get_load_offset(tuple<number<0>, number<64 / WG::kM>>{}));
|
||||
|
||||
auto scale_b_dram_window = make_tile_window(
|
||||
scale_b_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<NWarp * WG::kN>{}, number<64 / WG::kN>{}),
|
||||
scale_b_window.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
|
||||
const auto scale_b_dram_step_n = amd_wave_read_first_lane(
|
||||
scale_b_dram_window.get_load_offset(tuple<number<NWarp * WG::kN>, number<0>>{}));
|
||||
const auto scale_b_dram_step_k = amd_wave_read_first_lane(
|
||||
scale_b_dram_window.get_load_offset(tuple<number<0>, number<64 / WG::kN>>{}));
|
||||
|
||||
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
|
||||
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
|
||||
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
|
||||
|
||||
// ping pong buffer for scale A
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(scale_a_dram_window), KIterPerWarp / KXdlPack>,
|
||||
MIterPerWarp / MXdlPack>
|
||||
scale_a_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
MIterPerWarp / MXdlPack>
|
||||
scale_a_tile_tensor_ping;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
MIterPerWarp / MXdlPack>
|
||||
scale_a_tile_tensor_pong;
|
||||
statically_indexed_array<decltype(load_tile(scale_a_dram_window)), KPackIterPerWarp>,
|
||||
MPackIterPerWarp>
|
||||
scale_a_tile_tensor_ping, scale_a_tile_tensor_pong;
|
||||
|
||||
// ping pong buffer for scale B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(scale_b_dram_window), KIterPerWarp / KXdlPack>,
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_tile_tensor_ping;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_tile_tensor_pong;
|
||||
statically_indexed_array<decltype(load_tile(scale_b_dram_window)), KPackIterPerWarp>,
|
||||
NPackIterPerWarp>
|
||||
scale_b_tile_tensor_ping, scale_b_tile_tensor_pong;
|
||||
|
||||
auto async_load_tile_ = [](auto lds, auto dram) {
|
||||
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
|
||||
@@ -625,35 +633,31 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter});
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
});
|
||||
|
||||
// prefetch Scale A
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
|
||||
});
|
||||
});
|
||||
// move Scale A window to next K
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// prefetch Scale B
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
|
||||
});
|
||||
});
|
||||
// move Scale B window to next K
|
||||
@@ -667,7 +671,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock});
|
||||
}
|
||||
// initialize C
|
||||
clear_tile(c_block_tile);
|
||||
statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp>
|
||||
c_warp_tensors;
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}(
|
||||
[&](auto nIter) { clear_tile(c_warp_tensors(mIter)(nIter)); });
|
||||
});
|
||||
|
||||
statically_indexed_array<decltype(load_tile(a_warp_window_pong)), m_preload> a_warp_tensor;
|
||||
|
||||
@@ -688,40 +697,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
move_tile_window(b_flat_dram_windows(nIter),
|
||||
{0, BlockGemmShape::flatKPerBlock});
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
|
||||
});
|
||||
});
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
@@ -729,39 +735,22 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<m_iter, n_iter>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0],
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<m_iter, n_iter>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
(nIter_pack == NPackIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
@@ -802,81 +791,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
move_tile_window(b_flat_dram_windows(nIter),
|
||||
{0, BlockGemmShape::flatKPerBlock});
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+2)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
|
||||
});
|
||||
});
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
(nIter_pack == NPackIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
@@ -928,78 +896,54 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter),
|
||||
make_tuple(number<0>{}, number<kIter * KFlatPerBlockPerIter>{}));
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
|
||||
});
|
||||
});
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
|
||||
});
|
||||
});
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
(nIter_pack == NPackIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
@@ -1028,50 +972,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
Last2ndHotLoopScheduler();
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
(nIter_pack == NPackIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
@@ -1089,50 +1015,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
(nIter_pack == NPackIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
@@ -1151,7 +1059,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{
|
||||
static_assert(false, "Wrong TailNum");
|
||||
}
|
||||
return c_block_tile;
|
||||
return c_warp_tensors;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
@@ -58,7 +58,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
template <typename Problem, typename TensorView>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
|
||||
MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view)
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
@@ -107,7 +107,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
|
||||
{
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
@@ -140,7 +140,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
@@ -218,7 +218,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
@@ -255,7 +255,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
@@ -298,7 +298,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
@@ -335,7 +335,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
@@ -372,7 +372,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
@@ -394,7 +394,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
@@ -420,8 +420,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
return sizeof(ADataType) *
|
||||
MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
|
||||
return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor<Problem>().get_element_space_size() /
|
||||
APackedSize;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user