try to remove c_shuffle_lds

This commit is contained in:
sjfeng
2025-07-27 17:24:08 +08:00
parent 1264f4d2ab
commit bfb9f4002f
3 changed files with 167 additions and 71 deletions

View File

@@ -257,6 +257,84 @@ struct CShuffleEpilogue
}
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem)
{
constexpr int MRepeat = kMPerBlock / MPerIterationShuffle;
constexpr int NRepeat = kNPerBlock / NPerIterationShuffle;
static_assert(MPerXdl == 16);
constexpr int kM0 = MWave;
constexpr int kM2 = 4;
constexpr int kM1 = MPerXdl / kM2;
constexpr int kN0 = NWave;
constexpr int kN1 = NPerXdl;
constexpr int kN2 = NRepeat;
using IntrThreadShuffleEncode =
tile_distribution_encoding<sequence<>,
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>;
constexpr auto dram_tile_distribution =
make_static_tile_distribution(IntrThreadShuffleEncode{});
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
static_for<0, MRepeat, 1>{}([&](auto mIter) {
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
c_out_tensor.get_thread_buffer()[n_idx + 0 * c_warp_y_lengths.product()] =
type_convert<ODataType>(
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0]);
c_out_tensor.get_thread_buffer()[n_idx + 1 * c_warp_y_lengths.product()] =
type_convert<ODataType>(
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1]);
c_out_tensor.get_thread_buffer()[n_idx + 2 * c_warp_y_lengths.product()] =
type_convert<ODataType>(
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2]);
c_out_tensor.get_thread_buffer()[n_idx + 3 * c_warp_y_lengths.product()] =
type_convert<ODataType>(
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3]);
});
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
});
});
}
template <class, typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,