mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
finished with the development on test kernel
This commit is contained in:
@@ -14,7 +14,7 @@ mkdir build && cd build
|
||||
# (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Make the copy kernel executable
|
||||
make test_copy -j
|
||||
make test_copy_kernel -j
|
||||
```
|
||||
This will result in an executable `build/bin/test_copy_kernel`
|
||||
|
||||
|
||||
@@ -56,13 +56,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using BlockWaves = ck_tile::sequence<2, 1>;
|
||||
using BlockTile = ck_tile::sequence<64, 8>;
|
||||
using WaveTile = ck_tile::sequence<64, 8>;
|
||||
using Vector = ck_tile::sequence<1, 4>;
|
||||
using Vector = ck_tile::sequence<1, 8>;
|
||||
constexpr bool AsyncCopy = true;
|
||||
|
||||
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
|
||||
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
|
||||
using Problem = ck_tile::TileCopyProblem<XDataType, Shape, AsyncCopy>;
|
||||
using Kernel = ck_tile::TileCopy<Problem>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 128;
|
||||
|
||||
@@ -50,11 +50,12 @@ struct TileCopyShape
|
||||
static_assert(WaveGroupSize == WarpPerBlock_M * WarpPerBlock_N, "Inconsisten wave group size!");
|
||||
};
|
||||
|
||||
template <typename XDataType_, typename BlockShape_>
|
||||
template <typename XDataType_, typename BlockShape_, bool AsyncCopy_>
|
||||
struct TileCopyProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
static constexpr bool AsyncCopy = AsyncCopy_;
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
@@ -63,6 +64,8 @@ struct TileCopy
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
|
||||
static constexpr bool AsyncCopy = Problem::AsyncCopy;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution()
|
||||
{
|
||||
@@ -156,17 +159,26 @@ struct TileCopy
|
||||
|
||||
if(my_id == warp_id)
|
||||
{
|
||||
// load from DRAM to registers
|
||||
load_tile(dram_tile, x_block_window);
|
||||
if constexpr (AsyncCopy == false) {
|
||||
// load from DRAM to registers
|
||||
load_tile(dram_tile, x_block_window);
|
||||
|
||||
// store in lds
|
||||
store_tile(x_block_lds_window_no_dist, dram_tile);
|
||||
// store in lds
|
||||
store_tile(x_block_lds_window_no_dist, dram_tile);
|
||||
|
||||
// read from lds to registers
|
||||
load_tile(dram_tile, x_block_lds_window);
|
||||
// read from lds to registers
|
||||
load_tile(dram_tile, x_block_lds_window);
|
||||
|
||||
// store from registers to DRAM
|
||||
store_tile(y_block_window, dram_tile);
|
||||
// store from registers to DRAM
|
||||
store_tile(y_block_window, dram_tile);
|
||||
} else {
|
||||
async_load_tile(x_block_lds_window_no_dist, x_block_window);
|
||||
|
||||
load_tile(dram_tile, x_block_lds_window);
|
||||
|
||||
// store from registers to DRAM
|
||||
store_tile(y_block_window, dram_tile);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
move_tile_window(x_block_window, {0, S::Block_N});
|
||||
|
||||
@@ -369,62 +369,6 @@ struct tile_window_with_static_distribution
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
|
||||
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
|
||||
// check?)
|
||||
constexpr index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{}));
|
||||
|
||||
constexpr index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
constexpr index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
// TODO: we force CK_TILE_LDS_ADDR
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// read from bottom tensor
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
@@ -436,8 +380,6 @@ struct tile_window_with_static_distribution
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
smem += size_per_issue; // Note we manually increase the per-issue offset
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user