mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
build pass
This commit is contained in:
@@ -368,118 +368,118 @@ struct CShuffleEpilogue
|
||||
ScaleM scale_m,
|
||||
ScaleN scale_n)
|
||||
{
|
||||
const index_t iMWarp = get_warp_id() / kNWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
// const index_t iMWarp = get_warp_id() / kNWave;
|
||||
// const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
|
||||
// const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
// const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
// constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
|
||||
// auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
// constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
// auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
// static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
LdsTileDistr);
|
||||
// auto in_lds_window = make_tile_window(
|
||||
// o_lds_block,
|
||||
// make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
// {0, 0},
|
||||
// LdsTileDistr);
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
// auto out_lds_window = make_tile_window(
|
||||
// o_lds_block,
|
||||
// make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
// {0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
// using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
// sequence<0, 1>,
|
||||
// sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
// constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
// static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
// "Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
// using TileEncodingPattern =
|
||||
// TileDistributionEncodingPattern2D<kBlockSize,
|
||||
// MPerIterationShuffle,
|
||||
// NPerIterationShuffle,
|
||||
// GetVectorSizeC(),
|
||||
// tile_distribution_pattern::thread_raked,
|
||||
// Problem::kNumWaveGroups>;
|
||||
// constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
// auto d_dram_windows = generate_tuple(
|
||||
// [&](auto idx) {
|
||||
// return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
// },
|
||||
// number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
// constexpr auto c_warp_y_lengths =
|
||||
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
// static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// block_sync_lds();
|
||||
// constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
// constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
// constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
// lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
// merge_sequences(
|
||||
// sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
// c_warp_y_index_zeros),
|
||||
// merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
// c_warp_y_lengths));
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
// const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
// store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
// block_sync_lds();
|
||||
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
// auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
auto m1 = iMLane;
|
||||
float scale_B = scale_n[nIter * NPerIterationShuffle];
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
float scale_A = scale_m[mIter * MPerIterationShuffle + iMWarp * MPerXdl +
|
||||
m0 * kM1 * kM2 + m1 * kM2 + m2];
|
||||
c_out_tensor.get_thread_buffer()[m0 * kM2 + m2] *= scale_A * scale_B;
|
||||
});
|
||||
});
|
||||
// auto m1 = iMLane;
|
||||
// float scale_B = scale_n[nIter * NPerIterationShuffle];
|
||||
// static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
// static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
// float scale_A = scale_m[mIter * MPerIterationShuffle + iMWarp * MPerXdl +
|
||||
// m0 * kM1 * kM2 + m1 * kM2 + m2];
|
||||
// c_out_tensor.get_thread_buffer()[m0 * kM2 + m2] *= scale_A * scale_B;
|
||||
// });
|
||||
// });
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
// const auto ds_tensor = generate_tuple(
|
||||
// [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie(
|
||||
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
|
||||
// const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
// tie(c_out_tensor, c_out_tensor),
|
||||
// generate_tie(
|
||||
// [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
// tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
// if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
// {
|
||||
// store_tile(out_dram_window, c_out_tensor);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// update_tile(out_dram_window, c_out_tensor);
|
||||
// }
|
||||
// if constexpr(iAccess != num_access - 1)
|
||||
// {
|
||||
// constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
// move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
});
|
||||
// static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
// move_tile_window(d_dram_windows[idx],
|
||||
// {step.at(number<0>{}), step.at(number<1>{})});
|
||||
// });
|
||||
// }
|
||||
// });
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user