diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7a21634b7d..9c942a776d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
 
 ### Added
 
+* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
 * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
 * Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels.
 * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced).
diff --git a/Jenkinsfile b/Jenkinsfile
index c0efaa3b91..590ee92e90 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -801,6 +801,7 @@ def run_aiter_tests(Map conf=[:]){
         sh "python3 --version"
         sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py"
         sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py"
+        sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py"
     }
     catch(e){
         echo "Throwing error exception while running AITER tests"
diff --git a/example/ck_tile/39_copy/CMakeLists.txt b/example/ck_tile/39_copy/CMakeLists.txt
new file mode 100644
index 0000000000..98397a33d2
--- /dev/null
+++ b/example/ck_tile/39_copy/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_executable(tile_example_copy EXCLUDE_FROM_ALL copy_basic.cpp)
+
+# Impact: this flag keeps the compiler from making assumptions about memory
+# aliasing that could interfere with Composable Kernel's explicit memory access patterns.
+target_compile_options(tile_example_copy PRIVATE
+    -mllvm -enable-noalias-to-md-conversion=0
+)
diff --git a/example/ck_tile/39_copy/README.md b/example/ck_tile/39_copy/README.md
new file mode 100644
index 0000000000..f45fcb682b
--- /dev/null
+++ b/example/ck_tile/39_copy/README.md
@@ -0,0 +1,313 @@
+# CK Tile Framework: Getting Started with Tile Copy Operations
+
+## Overview
+
+### Copy Kernel
+A minimal CK Tile memory copy implementation demonstrating the basic setup required to write a kernel in CK Tile.
+This experimental kernel is intended for novice CK developers. It introduces the building blocks of CK Tile and provides a sandbox for experimenting with kernel parameters.
+
+## build
+```
+# in the root of ck_tile
+mkdir build && cd build
+# you can replace <arch> with the appropriate architecture
+# (for example gfx90a or gfx942) or leave it blank
+sh ../script/cmake-ck-dev.sh ../ <arch>
+# make the copy kernel executable
+make tile_example_copy -j
+```
+This results in the executable `build/bin/tile_example_copy`.
+
+## example
+```
+args:
+  -m        input matrix rows. (default 128)
+  -n        input matrix cols. (default 8)
+  -v        validation flag to check device results. (default 1)
+  -prec     datatype precision to use, fp16 or fp32. (default fp16)
+  -warmup   no. of warmup iterations. (default 50)
+  -repeat   no. of iterations for kernel execution time. (default 100)
+```
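+
+A typical invocation from the `build` directory (the `-m`/`-n` values here are illustrative):
+```
+./bin/tile_example_copy -m 512 -n 8 -prec fp16 -v 1
+```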
+
+## CK Tile Architecture Components
+
+The CK Tile framework is built around four key architectural components that work together to define and execute GPU kernels: shape, policy, problem, and pipeline.
+
+### **1. Shape**
+Defines the **hierarchical tile structure** and **memory layout** of the kernel:
+
+```cpp
+using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
+```
+
+**Components:**
+- **BlockWaves**: Number of concurrent waves per block (e.g., `seq<4, 1>` for 4 waves along M, 1 along N)
+- **BlockTile**: Total elements processed by one block (e.g., `seq<512, 8>`)
+- **WaveTile**: Elements processed by one wave (e.g., `seq<32, 8>`)
+- **Vector**: Elements processed by one thread (e.g., `seq<1, 4>` for 4 contiguous elements)
+
+**Purpose**: Defines the **work distribution hierarchy** from threads → waves → blocks.
+
+### **2. Problem**
+Defines the **data types** and **kernel configuration**:
+
+```cpp
+using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
+```
+
+**Components:**
+- **XDataType**: Input/output data type (e.g., `float`, `half`)
+- **Shape**: The tile shape defined above
+
+**Purpose**: Encapsulates **what** the kernel operates on and **how** it's configured.
+
+### **3. Policy**
+Defines the **memory access patterns** and **distribution strategies**:
+
+```cpp
+using Policy = ck_tile::TileCopyPolicy<Problem>;
+```
+
+**Key Functions:**
+- **MakeDRAMDistribution()**: Defines how threads access DRAM memory.
+
+**Purpose**: Defines **how** data is accessed and distributed across threads.
+
+### **4. Pipeline**
+Defines the **execution flow** and **memory movement patterns**:
+
+```cpp
+// Example pipeline stages:
+// 1. DRAM → Registers (load_tile)
+// 2. Registers → LDS (store_tile)
+// 3. LDS → Registers (load_tile with distribution)
+// 4. Registers → DRAM (store_tile)
+```
+
+**Purpose**: Defines the **sequence of operations** and **memory movement strategy**.
+
+### **Component Interaction**
+
+```cpp
+// Complete kernel definition
+using Shape   = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
+using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
+using Policy  = ck_tile::TileCopyPolicy<Problem>;
+using Kernel  = ck_tile::TileCopyKernel<Problem, Policy>;
+```
+
+**Flow:**
+1. **Shape** defines the tile structure and work distribution
+2. **Problem** combines data types with the shape
+3. **Policy** defines memory access patterns for the problem
+4. **Kernel** implements the actual computation using all components
+
+### **Why This Architecture?**
+
+#### **Separation of Concerns**
+- **Shape**: Focuses on **work distribution** and **tile structure**
+- **Problem**: Focuses on **data types** and **configuration**
+- **Policy**: Focuses on **memory access** and **optimization**
+- **Pipeline**: Focuses on **execution flow** and **synchronization**
+
+#### **Reusability**
+- The same **Shape** can be used with different **Problems**
+- The same **Policy** can be applied to different **Shapes**
+- **Pipelines** can be reused across different kernels
+
+#### **Performance Optimization**
+- **Shape** enables optimal work distribution
+- **Policy** enables optimal memory access patterns
+- **Pipeline** enables optimal execution flow
+
+## Core Concepts
+
+### Hierarchical Tile Structure
+
+The CK Tile framework organizes work in a hierarchical manner (a concrete configuration follows the list):
+
+1. **Vector**: Number of contiguous elements processed by a single thread
+   - Enables vectorized memory loads/stores.
+   - Example: `Vector = seq<1, 4>` means each thread loads 4 contiguous elements along the N dimension
+   - A Vector can be imagined as a thread-level tile
+
+2. **WaveTile**: Number of elements covered by a single wave (64 threads on AMD)
+   - Must satisfy: `Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize`
+   - This ensures the number of threads needed equals the wave size
+   - Example: `WaveTile = seq<64, 4>` with `Vector = seq<1, 4>` means:
+     - Each thread handles 4 elements (Vector_N = 4)
+     - The wave needs 64×4/4 = 64 threads to cover 64×4 = 256 elements
+     - Total elements = 256, which requires WaveSize = 64 threads
+
+3. **BlockTile**: Number of elements covered by one block (typically mapped to one CU)
+   - Example: `BlockTile = seq<256, 64>` means each block processes 256×64 elements
+
+4. **BlockWaves**: Number of concurrent waves active in a block
+   - Usually 4 waves per block on modern AMD GPUs
+   - Example: `BlockWaves = seq<4, 1>` means 4 waves along M dimension, 1 along N
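+
+Putting these together, the configuration used by `copy_basic.cpp` is one valid instance of this hierarchy:
+
+```cpp
+// Configuration from copy_basic.cpp (wave size 64)
+using Vector     = ck_tile::sequence<1, 4>;   // each thread handles 4 contiguous elements along N
+using WaveTile   = ck_tile::sequence<64, 4>;  // (64/1) * (4/4) = 64 threads == wave size
+using BlockWaves = ck_tile::sequence<4, 1>;   // 4 waves along M, 1 along N
+using BlockTile  = ck_tile::sequence<512, 4>; // 512/(4*64) = 2 wave repetitions along M
+```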
+
+### Wave Repetition
+
+In many scenarios, the total work (BlockTile) is larger than what the available waves can cover in a single iteration. This requires **wave repetition**:
+
+```cpp
+// Calculate how many times a wave needs to repeat to cover the entire block tile
+static constexpr index_t WaveRepetitionPerBlock_M =
+    Block_Tile_M / (Waves_Per_Block_M * Wave_Tile_M);
+static constexpr index_t WaveRepetitionPerBlock_N =
+    Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
+```
+
+**Key Insight**: When waves repeat, the effective work per thread becomes `Vector * Repeat`, not just `Vector`.
+
+## Tile Distribution Encoding
+
+The tile distribution encoding specifies how work is distributed across threads:
+
+```cpp
+constexpr auto outer_encoding =
+    tile_distribution_encoding<sequence<>,                                    // replication
+                               tuple<sequence<M0, M1, M2>, sequence<N0, N1>>, // hierarchy
+                               tuple<sequence<1>, sequence<1, 2>>,            // parallelism (major)
+                               tuple<sequence<1>, sequence<2, 0>>,            // parallelism (minor)
+                               sequence<1, 2>,                                // yield (major)
+                               sequence<0, 1>>{};                             // yield (minor)
+```
+
+### Encoding Parameters Explained
+
+- **M0, M1, M2**: Hierarchical distribution along M dimension
+  - M0: Number of wave iterations along M
+  - M1: Number of waves along M
+  - M2: Number of threads per wave along M
+- **N0, N1**: Distribution along N dimension
+  - N0: Number of threads along N
+  - N1: Vector size (elements per thread)
+- **YIELD arguments**: Both `Repeat` and `Vector`, because the effective work per thread is `Vector * Repeat`
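+
+For the configuration in `copy_basic.cpp` (`BlockTile = seq<512, 4>`, `Vector_N = 4`, wave size 64, block size 256), these parameters work out as:
+
+```cpp
+// N1 = Vector_N                 = 4           elements per thread along N
+// N0 = Block_Tile_N / N1        = 4 / 4    = 1  thread along N
+// M2 = WaveSize / N0            = 64 / 1   = 64 threads per wave along M
+// M1 = BlockSize / WaveSize     = 256 / 64 = 4  waves along M
+// M0 = Block_Tile_M / (M1 * M2) = 512 / 256 = 2 wave iterations (Repeat) along M
+```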
+
+## Tensor Abstractions
+
+### Tensor Descriptor
+Defines the logical structure of a tensor:
+```cpp
+auto desc = make_naive_tensor_descriptor(
+    make_tuple(M, N),    // tensor dimensions
+    make_tuple(N, 1),    // strides
+    number<Vector_N>{},  // vector length for vectorized access
+    number<1>{}          // guaranteed last dimension vector stride
+);
+```
+
+### Tensor View
+Combines a memory buffer with a tensor descriptor:
+```cpp
+auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
+    p_x,                 // memory buffer
+    make_tuple(M, N),    // dimensions
+    make_tuple(N, 1),    // strides
+    number<Vector_N>{},  // vector length
+    number<1>{}          // guaranteed last dimension vector stride
+);
+```
+
+### Tile Window
+A view into a specific tile of the tensor with thread distribution:
+```cpp
+auto x_window = make_tile_window(
+    x_m_n,                                  // tensor view
+    make_tuple(Block_Tile_M, Block_Tile_N), // tile size
+    {iM, 0},                                // tile origin
+    tile_distribution                       // how work is distributed among threads
+);
+```
+
+## The TileCopyKernel
+
+### Kernel Structure
+
+The `TileCopyKernel` implements a basic copy operation from input tensor `x` to output tensor `y`:
+
+```cpp
+template <typename Problem, typename Policy>
+struct TileCopyKernel
+{
+    CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
+    {
+        // 1. Create tensor views
+        // 2. Create tile windows
+        // 3. Iterate over N dimension tiles
+        // 4. Load, copy, and store data
+    }
+};
+```
+
+### Step-by-Step Execution
+
+1. **Tensor View Creation**:
+   ```cpp
+   const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
+       p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+   ```
+   - Creates views for both input and output tensors
+   - Specifies vectorized access with `Vector_N` elements per load
+
+2. **Tile Window Creation**:
+   ```cpp
+   auto x_window = make_tile_window(x_m_n,
+                                    make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
+                                    {iM, 0},
+                                    Policy::template MakeDRAMDistribution<Problem>());
+   ```
+   - Creates windows into specific tiles of the tensors
+   - Each block processes one tile starting at `{iM, 0}`
+   - The tile distribution determines how threads access data
+
+3. **N-Dimension Iteration**:
+   ```cpp
+   index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N));
+   for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
+   ```
+   - If the tensor's N dimension > Block_Tile_N, multiple iterations are needed
+   - Each iteration processes one tile along the N dimension
+
+4. **Load-Store Operations**:
+   ```cpp
+   dram_reg_tile dram_tile;
+   load_tile(dram_tile, x_window);                   // Load from global memory to registers
+   store_tile(y_window, dram_tile);                  // Store from registers to global memory
+   move_tile_window(x_window, {0, S::Block_Tile_N}); // Move to next N tile
+   move_tile_window(y_window, {0, S::Block_Tile_N});
+   ```
+
+### How Load/Store Works
+
+1. **Load Tile**:
+   - Each thread loads its assigned elements based on the tile distribution
+   - Vectorized loads enable efficient memory bandwidth utilization
+   - Data is distributed to per-thread register buffers
+
+2. **Store Tile**:
+   - Each thread writes its assigned elements back to global memory
+   - Maintains the same distribution pattern as the load
+
+3. **Tile Window Movement**:
+   - Moves the window to the next tile along the N dimension
+   - Enables processing of large tensors that don't fit in one tile
+
+## Memory Access Patterns
+
+### Vectorized Access
+- Enabled by specifying the vector length in tensor views
+- Each thread loads/stores multiple contiguous elements in one operation
+- Improves memory bandwidth utilization
+
+### Thread Distribution
+- The tile distribution encoding determines which threads access which elements
+- Ensures all threads participate and no data is missed
+- Enables memory coalescing for optimal performance
+
+### Coordinate Transform (Embed)
+- Maps multi-dimensional tensor indices to linear memory addresses
+- Handles stride calculations automatically
+- Enables efficient access to non-contiguous memory layouts
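+
+For a row-major M×N tensor with strides `(N, 1)`, the arithmetic behind the embed transform is simply (a plain illustration, not the CK Tile API):
+
+```cpp
+// offset of element (i, j) = i * stride_M + j * stride_N = i * N + j
+index_t linear_offset(index_t i, index_t j, index_t N) { return i * N + j; }
+```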
diff --git a/example/ck_tile/39_copy/copy_basic.cpp b/example/ck_tile/39_copy/copy_basic.cpp
new file mode 100644
index 0000000000..d46add879c
--- /dev/null
+++ b/example/ck_tile/39_copy/copy_basic.cpp
@@ -0,0 +1,147 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck_tile/host.hpp"
+#include <cstring>
+#include "copy_basic.hpp"
+
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("m", "128", "m dimension")
+        .insert("n", "8", "n dimension")
+        .insert("v", "1", "cpu validation or not")
+        .insert("prec", "fp16", "precision(fp16 or fp32)")
+        .insert("warmup", "50", "cold iter")
+        .insert("repeat", "100", "hot iter");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+template <typename DataType>
+bool run(const ck_tile::ArgParser& arg_parser)
+{
+    using XDataType = DataType;
+    using YDataType = DataType;
+
+    ck_tile::index_t m = arg_parser.get_int("m");
+    ck_tile::index_t n = arg_parser.get_int("n");
+    int do_validation  = arg_parser.get_int("v");
+    int warmup         = arg_parser.get_int("warmup");
+    int repeat         = arg_parser.get_int("repeat");
+
+    // Create host tensors
+    ck_tile::HostTensor<XDataType> x_host({m, n});     // input matrix
+    ck_tile::HostTensor<YDataType> y_host_ref({m, n}); // reference output matrix
+    ck_tile::HostTensor<YDataType> y_host_dev({m, n}); // device output matrix
+
+    // Initialize each row of the input with increasing values 1..n
+    for(int i = 0; i < m; i++)
+    {
+        for(int j = 0; j < n; j++)
+        {
+            x_host(i, j) = ck_tile::type_convert<XDataType>(j + 1);
+        }
+    }
+
+    // Allocate device memory
+    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
+
+    x_buf.ToDevice(x_host.data());
+
+    // Define tile configuration
+    using Vector     = ck_tile::sequence<1, 4>;   // vector size along M and N dimensions
+    using WaveTile   = ck_tile::sequence<64, 4>;  // wave size along M and N dimensions
+    using BlockWaves = ck_tile::sequence<4, 1>;   // number of waves along M and N dimensions
+    using BlockTile  = ck_tile::sequence<512, 4>; // block size along M and N dimensions
+
+    // Calculate grid size
+    ck_tile::index_t kGridSize =
+        ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{}));
+    std::cout << "grid size (number of blocks per grid) " << kGridSize << std::endl;
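+
+    // For the defaults (m = 128, n = 8) with BlockTile = seq<512, 4>, this gives
+    // kGridSize = ceil(128 / 512) = 1 block along M; each block then iterates
+    // ceil(8 / 4) = 2 tiles along N inside the kernel.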
+
+    // Define kernel types
+    using Shape   = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
+    using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
+    using Policy  = ck_tile::TileCopyPolicy<Problem>;
+    using Kernel  = ck_tile::ElementWiseTileCopyKernel<Problem, Policy>;
+    // using Kernel = ck_tile::TileCopyKernel<Problem, Policy>;
+    // using Kernel = ck_tile::TileCopyKernel_LDS<Problem, Policy>;
+
+    // question: Why do we not have a pipeline?
+    // answer: A basic copy operation does not need one. We intentionally omit the
+    // pipeline in this example and compose the kernel directly from Problem and Policy.
+
+    constexpr ck_tile::index_t kBlockSize = Shape::BlockSize;
+
+    // Print configuration information
+    std::cout << "block size (number of threads per block) " << kBlockSize << std::endl;
+    std::cout << "wave size (number of threads per wave) " << ck_tile::get_warp_size() << std::endl;
+    std::cout << "block waves (number of waves per block) " << BlockWaves::at(ck_tile::number<0>{})
+              << " " << BlockWaves::at(ck_tile::number<1>{}) << std::endl;
+    std::cout << "block tile (number of elements per block) " << BlockTile::at(ck_tile::number<0>{})
+              << " " << BlockTile::at(ck_tile::number<1>{}) << std::endl;
+    std::cout << "wave tile (number of elements per wave) " << WaveTile::at(ck_tile::number<0>{})
+              << " " << WaveTile::at(ck_tile::number<1>{}) << std::endl;
+    std::cout << "vector (number of elements per thread) " << Vector::at(ck_tile::number<0>{})
+              << " " << Vector::at(ck_tile::number<1>{}) << std::endl;
+    std::cout << "WaveRepetitionPerBlock_M = " << Shape::WaveRepetitionPerBlock_M << " --> ("
+              << Shape::Block_Tile_M << "/" << Shape::Waves_Per_Block_M << "*" << Shape::Wave_Tile_M
+              << ")" << std::endl;
+    std::cout << "WaveRepetitionPerBlock_N = " << Shape::WaveRepetitionPerBlock_N << " --> ("
+              << Shape::Block_Tile_N << "/" << Shape::Waves_Per_Block_N << "*" << Shape::Wave_Tile_N
+              << ")" << std::endl;
+
+    // Launch kernel
+    float ave_time = launch_kernel(
+        ck_tile::stream_config{nullptr, true, 1, warmup, repeat},
+        ck_tile::make_kernel(Kernel{},
+                             kGridSize,
+                             kBlockSize,
+                             0,
+                             static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
+                             static_cast<XDataType*>(y_buf.GetDeviceBuffer()),
+                             m,
+                             n));
+
+    // Calculate and print performance metrics
+    std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m * n;
+    float gb_per_sec      = num_btype / 1.E6 / ave_time;
+    std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
+
+    bool pass = true;
+
+    if(do_validation)
+    {
+        // Copy results back to host
+        y_buf.FromDevice(y_host_dev.mData.data());
+        // Use exact equality (tolerance = 0) since a copy must be bit-exact
+        pass = ck_tile::check_err(y_host_dev, x_host, "Error: Copy operation failed!", 0.0, 0.0);
+        std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
+    }
+
+    // Print tensors for debugging
+    // std::cout << "Input matrix (x_host):" << std::endl;
+    // std::cout << x_host << std::endl;
+    // std::cout << "Output matrix (y_host_dev):" << std::endl;
+    // std::cout << y_host_dev << std::endl;
+
+    return pass;
+}
+
+int main(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    if(arg_parser.get_str("prec") == "fp16")
+        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+    else
+        return run<float>(arg_parser) ? 0 : -2;
+}
diff --git a/example/ck_tile/39_copy/copy_basic.hpp b/example/ck_tile/39_copy/copy_basic.hpp
new file mode 100644
index 0000000000..bbeb964fda
--- /dev/null
+++ b/example/ck_tile/39_copy/copy_basic.hpp
@@ -0,0 +1,369 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+
+namespace ck_tile {
+
+/**
+ * @brief Tile copy shape configuration
+ *
+ * @tparam BlockWaves Number of waves per block, seq<Waves_M, Waves_N>
+ * @tparam BlockTile  Block tile size, seq<Block_Tile_M, Block_Tile_N>
+ * @tparam WaveTile   Wave tile size, seq<Wave_Tile_M, Wave_Tile_N>
+ * @tparam Vector     Contiguous elements per thread (vector size), seq<Vector_M, Vector_N>
+ */
+template <typename BlockWaves, typename BlockTile, typename WaveTile, typename Vector>
+struct TileCopyShape
+{
+    // Vector dimensions for memory operations
+    static constexpr index_t Vector_M = Vector::at(number<0>{});
+    static constexpr index_t Vector_N = Vector::at(number<1>{});
+
+    // Wave tile dimensions
+    static constexpr index_t Wave_Tile_M = WaveTile::at(number<0>{});
+    static constexpr index_t Wave_Tile_N = WaveTile::at(number<1>{});
+
+    // Block tile dimensions
+    static constexpr index_t Block_Tile_M = BlockTile::at(number<0>{});
+    static constexpr index_t Block_Tile_N = BlockTile::at(number<1>{});
+
+    // Waves per block configuration
+    static constexpr index_t Waves_Per_Block_M = BlockWaves::at(number<0>{});
+    static constexpr index_t Waves_Per_Block_N = BlockWaves::at(number<1>{});
+
+    // Calculate wave repetition to cover the entire block tile
+    static constexpr index_t WaveRepetitionPerBlock_M =
+        Block_Tile_M / (Waves_Per_Block_M * Wave_Tile_M);
+    static constexpr index_t WaveRepetitionPerBlock_N =
+        Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
+
+    // Hardware configuration
+    static constexpr index_t WaveSize  = get_warp_size();
+    static constexpr index_t BlockSize = Waves_Per_Block_M * Waves_Per_Block_N * WaveSize;
+
+    // Configuration validation
+    static_assert(Block_Tile_M > 0 && Block_Tile_N > 0, "Block tile dimensions must be positive");
+    static_assert(Wave_Tile_M > 0 && Wave_Tile_N > 0, "Wave tile dimensions must be positive");
+    static_assert(Vector_M > 0 && Vector_N > 0, "Vector dimensions must be positive");
+    static_assert(Waves_Per_Block_M > 0 && Waves_Per_Block_N > 0,
+                  "Waves per block must be positive");
+    static_assert(Waves_Per_Block_M * Wave_Tile_M > 0,
+                  "Invalid wave configuration for M dimension");
+    static_assert(Waves_Per_Block_N * Wave_Tile_N > 0,
+                  "Invalid wave configuration for N dimension");
+
+    // Ensure wave tile dimensions align with the wave size
+    static_assert(Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize,
+                  "(Wave_Tile_M/Vector_M) * (Wave_Tile_N/Vector_N) != WaveSize");
+};
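+
+// Illustration of the wave-size constraint above, assuming get_warp_size() == 64:
+//   WaveTile = seq<64, 4>, Vector = seq<1, 4>  ->  (64/1) * (4/4) == 64, OK
+//   WaveTile = seq<32, 4>, Vector = seq<1, 4>  ->  (32/1) * (4/4) == 32, static_assert fires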
+
+/**
+ * @brief Problem definition for the tile copy operation
+ */
+template <typename XDataType_, typename BlockShape_>
+struct TileCopyProblem
+{
+    using XDataType  = remove_cvref_t<XDataType_>;
+    using BlockShape = remove_cvref_t<BlockShape_>;
+};
+
+/**
+ * @brief Policy for the tile copy operation
+ */
+template <typename Problem_>
+struct TileCopyPolicy
+{
+    using Problem   = ck_tile::remove_cvref_t<Problem_>;
+    using XDataType = typename Problem::XDataType;
+
+    /**
+     * @brief Create DRAM distribution for optimal memory access
+     */
+    template <typename Problem>
+    CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution()
+    {
+        using S = typename Problem::BlockShape;
+
+        constexpr index_t wave_size  = S::WaveSize;
+        constexpr index_t block_size = S::BlockSize;
+
+        // Distribution calculation to ensure all threads participate
+        constexpr index_t N1 = S::Vector_N;                 // Elements per thread along N
+        constexpr index_t N0 = S::Block_Tile_N / N1;        // Threads needed along N
+
+        constexpr index_t M2 = wave_size / N0;              // Threads per wave along M
+        constexpr index_t M1 = block_size / wave_size;      // Waves possible along M
+        constexpr index_t M0 = S::Block_Tile_M / (M1 * M2); // Wave iterations along M
+
+        // Validate complete coverage
+        static_assert(M0 * M1 * M2 * N0 * N1 == S::Block_Tile_M * S::Block_Tile_N,
+                      "Tile distribution must cover entire block tile");
+
+        constexpr auto outer_encoding =
+            tile_distribution_encoding<sequence<>,
+                                       tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
+                                       tuple<sequence<1>, sequence<1, 2>>,
+                                       tuple<sequence<1>, sequence<2, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 1>>{};
+        return make_static_tile_distribution(outer_encoding);
+    }
+};
+
+/**
+ * @brief Direct copy kernel from global memory to global memory
+ */
+template <typename Problem_, typename Policy_>
+struct TileCopyKernel
+{
+    using Problem   = ck_tile::remove_cvref_t<Problem_>;
+    using XDataType = typename Problem::XDataType;
+    using Policy    = ck_tile::remove_cvref_t<Policy_>;
+
+    CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
+    {
+        using S = typename Problem::BlockShape;
+
+        // Calculate the tile block origin and validate bounds.
+        // __builtin_amdgcn_readfirstlane broadcasts the same value to all threads in a wave,
+        // which saves VGPRs by avoiding per-thread storage of the same value.
+        const auto tile_block_origin_m =
+            __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M);
+        if(tile_block_origin_m >= M)
+        {
+            return; // Early exit for out-of-bounds blocks
+        }
+
+        // Create tensor views for input and output
+        const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
+            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+
+        const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
+            p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+
+        // Create tile windows with the DRAM distribution
+        auto x_window =
+            make_tile_window(x_m_n,
+                             make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
+                             {tile_block_origin_m, 0},
+                             Policy::template MakeDRAMDistribution<Problem>());
+
+        auto y_window =
+            make_tile_window(y_m_n,
+                             make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
+                             {tile_block_origin_m, 0},
+                             Policy::template MakeDRAMDistribution<Problem>());
+
+        // Calculate iterations needed to cover the N dimension.
+        // Note: This kernel uses data parallelism only in the M dimension.
+        // Each block processes one tile in M dimension, but iterates through N dimension tiles.
+        // This design choice is for simplicity and to avoid complex tile distribution.
+        index_t num_n_tile_iteration =
+            __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N));
+
+        // Get the tile distribution for the register tensor
+        auto DramTileDist   = x_window.get_tile_distribution();
+        using dram_reg_tile = decltype(make_static_distributed_tensor<XDataType>(DramTileDist));
+
+        // Main copy loop - processes N dimension tiles sequentially within each block
+        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
+        {
+            dram_reg_tile dram_tile;
+
+            // Direct copy implementation
+            load_tile(dram_tile, x_window);
+            store_tile(y_window, dram_tile);
+
+            // Move to the next N tile
+            move_tile_window(x_window, {0, S::Block_Tile_N});
+            move_tile_window(y_window, {0, S::Block_Tile_N});
+        }
+    }
+};
+
+/**
+ * @brief Element-wise copy kernel for data transformation scenarios
+ *
+ * This kernel performs element-wise copy operations, allowing for data transformation
+ * during the copy process. Useful when data needs to be processed or converted
+ * between different formats.
+ */ +template +struct ElementWiseTileCopyKernel +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = typename Problem::XDataType; + using Policy = ck_tile::remove_cvref_t; + + CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const + { + using S = typename Problem::BlockShape; + + // Calculate block origin and validate bounds + // Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave + // This saves VGPR usage by avoiding per-thread storage of the same value + const auto tile_block_origin_m = + __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M); + if(tile_block_origin_m >= M) + { + return; // Early exit for out-of-bounds blocks + } + + // Create tensor views for input and output + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + const auto y_m_n = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + // Create tile windows with DRAM distribution + auto x_window = + make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {tile_block_origin_m, 0}, + Policy::template MakeDRAMDistribution()); + + auto y_window = + make_tile_window(y_m_n, + make_tuple(number{}, number{}), + {tile_block_origin_m, 0}, + Policy::template MakeDRAMDistribution()); + + // Calculate iterations needed to cover N dimension + // Note: This kernel uses data parallelism only in the M dimension. + // Each block processes one tile in M dimension, but iterates through N dimension tiles. + // This design choice is for simplicity and to avoid complex tile distribution. + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N)); + + // Main element-wise copy loop - processes N dimension tiles sequentially within each block + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + // Element-wise copy implementation for data transformation + const auto xa = load_tile(x_window); + auto y_compute = load_tile(y_window); + + constexpr auto spans = decltype(xa)::get_distributed_spans(); + + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + const auto x = ck_tile::type_convert(xa[i_j_idx]); + y_compute(i_j_idx) = x; + }); + }); + + store_tile(y_window, y_compute); + + // Move to next N tile + move_tile_window(x_window, {0, S::Block_Tile_N}); + move_tile_window(y_window, {0, S::Block_Tile_N}); + } + } +}; + +/** + * @brief LDS-based copy kernel for data processing scenarios + * + * This kernel copies data from global memory to LDS and then to global memory, + * useful when data needs to be processed or transformed during the copy operation. 
+ */ +template +struct TileCopyKernel_LDS +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = typename Problem::XDataType; + using Policy = ck_tile::remove_cvref_t; + + CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const + { + using S = typename Problem::BlockShape; + + // Calculate block origin and validate bounds + // Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave + // This saves VGPR usage by avoiding per-thread storage of the same value + const auto tile_block_origin_m = + __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M); + if(tile_block_origin_m >= M) + { + return; // Early exit for out-of-bounds blocks + } + + // LDS buffer allocation + __shared__ XDataType x_lds_buffer[S::Block_Tile_M * S::Block_Tile_N]; + + // LDS tensor descriptor and view + const auto x_lds_descriptor = + make_naive_tensor_descriptor(make_tuple(S::Block_Tile_M, S::Block_Tile_N), + make_tuple(S::Block_Tile_N, 1), + number{}, + number<1>{}); + + auto x_lds_view = make_tensor_view(x_lds_buffer, x_lds_descriptor); + + // LDS windows with different distributions for optimal access patterns + auto x_lds_write_window = make_tile_window( + x_lds_view, make_tuple(number{}, number{}), {0, 0}); + + auto x_lds_read_window = + make_tile_window(x_lds_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeDRAMDistribution()); + + // Global memory tensor views + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + const auto y_m_n = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + // Global memory tile windows + auto x_window = + make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {tile_block_origin_m, 0}, + Policy::template MakeDRAMDistribution()); + + auto y_window = + make_tile_window(y_m_n, + make_tuple(number{}, number{}), + {tile_block_origin_m, 0}); + + // Calculate iterations needed to cover N dimension + // Note: This kernel uses data parallelism only in the M dimension. + // Each block processes one tile in M dimension, but iterates through N dimension tiles. + // This design choice is for simplicity and to avoid complex tile distribution. + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N)); + + // Main copy loop with LDS staging - processes N dimension tiles sequentially within each + // block + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + // Global memory to LDS + auto dram_tile = load_tile(x_window); + store_tile(x_lds_write_window, dram_tile); + + // Synchronize LDS access + block_sync_lds(); + + // LDS to global memory + auto lds_tile = load_tile(x_lds_read_window); + store_tile(y_window, lds_tile); + + // Move to next N tile + move_tile_window(x_window, {0, S::Block_Tile_N}); + move_tile_window(y_window, {0, S::Block_Tile_N}); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 630b96ede0..8fce70ba04 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -23,3 +23,4 @@ add_subdirectory(20_grouped_convolution) add_subdirectory(21_elementwise) add_subdirectory(35_batched_transpose) add_subdirectory(38_block_scale_gemm) +add_subdirectory(39_copy)