support flatmm scaling

This commit is contained in:
Feng Shijie
2025-07-23 19:04:22 +00:00
parent 3f7d848dd3
commit 5a1183ebbd
7 changed files with 476 additions and 318 deletions

View File

@@ -282,8 +282,8 @@ struct CShuffleEpilogue
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
@@ -334,8 +334,8 @@ struct CShuffleEpilogue
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie(
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
@@ -360,7 +360,12 @@ struct CShuffleEpilogue
}
});
}
template <typename ODramWindow, typename OAccTile, typename DsDramWindows, typename ScaleM, typename ScaleN>
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM,
typename ScaleN>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
@@ -368,118 +373,133 @@ struct CShuffleEpilogue
ScaleM scale_m,
ScaleN scale_n)
{
// const index_t iMWarp = get_warp_id() / kNWave;
// const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
// const index_t iMLane = get_lane_id() / NPerXdl;
// const index_t iNLane = get_lane_id() % NPerXdl;
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
// constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
// auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
// constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
// auto o_lds_block = make_tensor_view<address_space_enum::lds>(
// static_cast<ODataType*>(p_smem), lds_block_desc);
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);
// auto in_lds_window = make_tile_window(
// o_lds_block,
// make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
// {0, 0},
// LdsTileDistr);
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
// auto out_lds_window = make_tile_window(
// o_lds_block,
// make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
// {0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
// using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
// sequence<0, 1>,
// sequence<MPerIterationShuffle, NPerIterationShuffle>>;
// constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
// static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
// "Currently, the CShuffle Epilogue only supports the Row Major Output layout");
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
// using TileEncodingPattern =
// TileDistributionEncodingPattern2D<kBlockSize,
// MPerIterationShuffle,
// NPerIterationShuffle,
// GetVectorSizeC(),
// tile_distribution_pattern::thread_raked,
// Problem::kNumWaveGroups>;
// constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
// auto d_dram_windows = generate_tuple(
// [&](auto idx) {
// return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
// },
// number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr int kM2 = 4; // Val
constexpr int kM1 = (64 / NPerXdl); // Thr
constexpr int kM0 = MPerXdl / kM1; // Val
// static_for<0, num_access, 1>{}([&](auto iAccess) {
// block_sync_lds();
// constexpr auto idx_y_start = SFC::get_index(iAccess);
const index_t iMWarp = get_warp_id() / NWave;
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
const index_t iMLane = get_lane_id() / NPerXdl;
const index_t iNLane = get_lane_id() % NPerXdl;
// constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
// constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
static_for<0, num_access, 1>{}([&](auto iAccess) {
block_sync_lds();
constexpr auto idx_y_start = SFC::get_index(iAccess);
// lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
// merge_sequences(
// sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
// c_warp_y_index_zeros),
// merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
// c_warp_y_lengths));
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
// const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
// store_tile(in_lds_window, c_warptile_in_tensor_casted);
// block_sync_lds();
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
float scale_B =
scale_n[nIter * NPerIterationShuffle +
iNWarp * NumNXdlPerWavePerShuffle * NPerXdl + n_xdl * NPerXdl + iNLane];
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl * NumMXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
// auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
// auto m1 = iMLane;
// float scale_B = scale_n[nIter * NPerIterationShuffle];
// static_for<0, kM0, 1>{}([&](auto m0) {
// static_for<0, kM2, 1>{}([&](auto m2) {
// float scale_A = scale_m[mIter * MPerIterationShuffle + iMWarp * MPerXdl +
// m0 * kM1 * kM2 + m1 * kM2 + m2];
// c_out_tensor.get_thread_buffer()[m0 * kM2 + m2] *= scale_A * scale_B;
// });
// });
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
float scale_A =
scale_m[mIter * MPerIterationShuffle +
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + m2];
lds_tile.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
scale_A * scale_B;
});
});
});
});
// const auto ds_tensor = generate_tuple(
// [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
// const auto c_ds_tiles = concat_tuple_of_reference(
// tie(c_out_tensor, c_out_tensor),
// generate_tie(
// [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
store_tile(in_lds_window, c_warptile_in_tensor_casted);
block_sync_lds();
// tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
// if constexpr(MemoryOperation == memory_operation_enum::set)
// {
// store_tile(out_dram_window, c_out_tensor);
// }
// else
// {
// update_tile(out_dram_window, c_out_tensor);
// }
// if constexpr(iAccess != num_access - 1)
// {
// constexpr auto step = SFC::get_forward_step(iAccess);
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
// move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
// static_for<0, NumDTensor, 1>{}([&](auto idx) {
// move_tile_window(d_dram_windows[idx],
// {step.at(number<0>{}), step.at(number<1>{})});
// });
// }
// });
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
});
}
};
} // namespace ck_tile