diff --git a/example/ck_tile/36_copy/CMakeLists.txt b/example/ck_tile/36_copy/CMakeLists.txt new file mode 100644 index 0000000000..d1b9ba923c --- /dev/null +++ b/example/ck_tile/36_copy/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(test_copy_kernel EXCLUDE_FROM_ALL test_copy.cpp) +target_compile_options(test_copy_kernel PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 +) \ No newline at end of file diff --git a/example/ck_tile/36_copy/README.md b/example/ck_tile/36_copy/README.md new file mode 100644 index 0000000000..7856f0b4bd --- /dev/null +++ b/example/ck_tile/36_copy/README.md @@ -0,0 +1,31 @@ +# Copy Kernel +This folder contains basic setup code designed to provide a platform for novice +CK_Tile kernel developers to test basic functionality with minimal additional +code compared to the functional code. Sample functional code for a simple +tile distribution for DRAM window and LDS window are provided and data is moved +from DRAM to registers, registers to LDS, LDS to registers and finally data +is moved to output DRAM window for a simple copy operation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture +# (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# Make the copy kernel executable +make test_copy -j +``` +This will result in an executable `build/bin/test_copy_kernel` + +## example +``` +args: + -m input matrix rows. (default 64) + -n input matrix cols. (default 8) + -id warp to use for computation. (default 0) + -v validation flag to check device results. (default 1) + -prec datatype precision to use. (default fp16) + -warmup no. of warmup iterations. (default 50) + -repeat no. of iterations for kernel execution time. (default 100) +``` \ No newline at end of file diff --git a/example/ck_tile/36_copy/test_copy.cpp b/example/ck_tile/36_copy/test_copy.cpp new file mode 100644 index 0000000000..81ea5255fc --- /dev/null +++ b/example/ck_tile/36_copy/test_copy.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include +#include "test_copy.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "64", "m dimension") + .insert("n", "8", "n dimension") + .insert("id", "0", "warp to use") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "50", "cold iter") + .insert("repeat", "100", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using YDataType = DataType; + + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t warp_id = arg_parser.get_int("id"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + ck_tile::HostTensor x_host({m, n}); + ck_tile::HostTensor y_host_ref({m, n}); + ck_tile::HostTensor y_host_dev({m, n}); + + // ck_tile::FillConstant{1.f}(x_host); + ck_tile::half_t value = 1; + for(int i = 0; i < m; i++) + { + value = 1; + for(int j = 0; j < n; j++) + { + x_host(i, j) = value++; + } + } + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using BlockWaves = ck_tile::sequence<2, 1>; + using BlockTile = ck_tile::sequence<64, 8>; + using WaveTile = ck_tile::sequence<64, 8>; + using Vector = ck_tile::sequence<1, 4>; + + ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); + std::cout << "grid size " << kGridSize << std::endl; + + using Shape = ck_tile::TileCopyShape; + using Problem = ck_tile::TileCopyProblem; + using Kernel = ck_tile::TileCopy; + + constexpr ck_tile::index_t kBlockSize = 128; + constexpr ck_tile::index_t kBlockPerCu = 1; + std::cout << "block size " << kBlockSize << std::endl; + std::cout << "warp SIze " << ck_tile::get_warp_size() << std::endl; + std::cout << "warps per block _M " << Shape::WarpPerBlock_M << " " << Shape::WarpPerBlock_N + << std::endl; + std::cout << "Block waves: " << BlockWaves::at(ck_tile::number<0>{}) << " " + << BlockWaves::at(ck_tile::number<1>{}) << std::endl; + std::cout << " Wave Groups: " << Shape::WaveGroups << std::endl; + + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n, + warp_id)); + + std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + // reference + y_buf.FromDevice(y_host_dev.mData.data()); + pass = ck_tile::check_err(y_host_dev, x_host); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + return run(arg_parser) ? 0 : -2; +} diff --git a/example/ck_tile/36_copy/test_copy.hpp b/example/ck_tile/36_copy/test_copy.hpp new file mode 100644 index 0000000000..8fed22a3d0 --- /dev/null +++ b/example/ck_tile/36_copy/test_copy.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +namespace ck_tile { + +template + typename BlockTile, // block size, seq + typename WaveTile, // warp size, seq + typename Vector> // contiguous elements(vector size) along seq +struct TileCopyShape +{ + // We split Workgroup waves into two specialized groups. + // One for reading data from global -> LDS, the other is doing reduction + static constexpr index_t WaveGroups = 2; + static constexpr index_t MWarps = BlockWaves::at(number<0>{}); + static constexpr index_t NWarps = BlockWaves::at(number<0>{}); + + static constexpr index_t Block_M = BlockTile::at(number<0>{}); + static constexpr index_t Block_N = BlockTile::at(number<1>{}); + + static constexpr index_t Warp_M = WaveTile::at(number<0>{}); + static constexpr index_t Warp_N = WaveTile::at(number<1>{}); + + static constexpr index_t Vector_M = Vector::at(number<0>{}); + static constexpr index_t Vector_N = Vector::at(number<1>{}); + + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + static constexpr index_t WarpPerBlock_M = + integer_divide_ceil(BlockWaves::at(number<0>{}), WaveGroups); + static constexpr index_t WarpPerBlock_N = + integer_divide_ceil(BlockWaves::at(number<1>{}), WaveGroups); + + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + static constexpr index_t WaveNum = reduce_on_sequence(BlockWaves{}, multiplies{}, number<1>{}); + + static constexpr index_t BlockSize = get_warp_size() * WaveNum; + static constexpr index_t WaveGroupSize = WaveNum / WaveGroups; + static_assert(WaveGroupSize == WarpPerBlock_M * WarpPerBlock_N, "Inconsisten wave group size!"); +}; + +template +struct TileCopyProblem +{ + using XDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; +}; + +template +struct TileCopy +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = typename Problem::XDataType; + + template + CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() + { + using S = typename Problem::BlockShape; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t X0 = S::ThreadPerWarp_N; // threads needed along N dimension, fastest + // changing with given vector size. + constexpr index_t X1 = + S::Vector_N; // no. of elements along N dimensions to be read by each thread. + + constexpr index_t Y0 = + S::WaveNum / S::WaveGroups; // no. of active warps working in this thread block. + constexpr index_t Y1 = warp_size / X0; // no. of threads in a warp needed along M dimension. + constexpr index_t Y2 = + S::Warp_M / + (Y1 * + Y0); // no. of iterations each warp needs to perform to cover the entire tile window. + + constexpr auto outer_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}; + return make_static_tile_distribution(outer_encoding); + } + + CK_TILE_DEVICE void + operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + { + using S = typename Problem::BlockShape; + + // LDS Data. + __shared__ XDataType x_lds[number{} * number{}]; + XDataType* __restrict__ p_x_lds = static_cast(x_lds); + + const auto x_lds_desc = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, 1), + number{}, + number<1>{}); + + auto x_lds_block_desc = transform_tensor_descriptor( + x_lds_desc, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform( + make_tuple(number{} / S::Vector_N, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto x_lds_view = make_tensor_view(p_x_lds, x_lds_block_desc); + + auto x_block_lds_window = + make_tile_window(x_lds_view, + make_tuple(number{}, number{}), + {0, 0}, + MakeDRAMDistribution()); + auto x_block_lds_window_no_dist = make_tile_window( + x_lds_view, make_tuple(number{}, number{}), {0, 0}); + + // Input tensor + const auto iM = get_block_id() * S::Block_M; + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + auto x_block_window = + make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {iM, 0}, + MakeDRAMDistribution()); + + // Output tensor + const auto y_m = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + auto y_block_window = + make_tile_window(y_m, make_tuple(number{}, number{}), {iM, 0}); + + // Programming logic + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + auto my_id = get_warp_id(); + + auto DramTileDist = x_block_window.get_tile_distribution(); + using dram_reg_tile = decltype(make_static_distributed_tensor(DramTileDist)); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + dram_reg_tile dram_tile; + + if(my_id == warp_id) + { + // load from DRAM to registers + load_tile(dram_tile, x_block_window); + + // store in lds + store_tile(x_block_lds_window_no_dist, dram_tile); + + // read from lds to registers + load_tile(dram_tile, x_block_lds_window); + + // store from registers to DRAM + store_tile(y_block_window, dram_tile); + } + __syncthreads(); + move_tile_window(x_block_window, {0, S::Block_N}); + move_tile_window(y_block_window, {0, S::Block_N}); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 88efe0d8d9..d479cd35f6 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -19,3 +19,4 @@ add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) add_subdirectory(35_batched_transpose) +add_subdirectory(36_copy)