diff --git a/example/ck_tile/39_copy/README.md b/example/ck_tile/39_copy/README.md
index f45fcb682b..fa98cc1de6 100644
--- a/example/ck_tile/39_copy/README.md
+++ b/example/ck_tile/39_copy/README.md
@@ -38,14 +38,14 @@ The CK Tile framework is built around four key architectural components that wor
 Defines the **hierarchical tile structure** and **memory layout** of the kernel:
 
 ```cpp
-using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
+using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
 ```
 
 **Components:**
 - **BlockWaves**: Number of concurrent waves per block (e.g., `seq<4, 1>` for 4 waves along M, 1 along N)
 - **BlockTile**: Total elements processed by one block (e.g., `seq<512, 8>`)
 - **WaveTile**: Elements processed by one wave (e.g., `seq<32, 8>`)
-- **Vector**: Elements processed by one thread (e.g., `seq<1, 4>` for 4 contiguous elements)
+- **ThreadTile**: Elements processed by one thread (e.g., `seq<1, 4>` for 4 contiguous elements)
 
 **Purpose**: Defines the **work distribution hierarchy** from threads → waves → blocks.
 
@@ -91,7 +91,7 @@ Defines the **execution flow** and **memory movement patterns**:
 
 ```cpp
 // Complete kernel definition
-using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
+using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
 using Problem = ck_tile::TileCopyProblem;
 using Policy = ck_tile::TileCopyPolicy;
 using Kernel = ck_tile::TileCopyKernel;
@@ -113,7 +113,7 @@ using Kernel = ck_tile::TileCopyKernel;
 #### **Reusability**
 
 - Same **Shape** can be used with different **Problems**
-- Same **Policy** can be applied to different **Shapes**
+- Same **Policy** can be applied to different **Problems**
 - **Pipelines** can be reused across different kernels
 
 #### **Performance Optimization**
@@ -127,16 +127,16 @@ using Kernel = ck_tile::TileCopyKernel;
 
 The CK Tile framework organizes work in a hierarchical manner:
 
-1. **Vector**: Number of contiguous elements processed by a single thread
+1. **ThreadTile**: Number of contiguous elements processed by a single thread
    - Enables vectorized memory loads/stores.
-   - Example: `Vector = seq<1, 4>` means each thread loads 4 contiguous elements along the N dimension
-   - A Vector can be imagined as a thread-level tile
+   - Example: `ThreadTile = seq<1, 4>` means each thread loads 4 contiguous elements along the N dimension
+   - A ThreadTile can be imagined as a thread-level tile
 
-2. **WaveTile**: Number of elements covered by a single wave (64 threads on AMD)
-   - Must satisfy: `Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize`
+2. **WaveTile**: Number of elements covered by a single wave (64 threads on CDNA, 32 threads on RDNA)
+   - Must satisfy: `Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize`
    - This ensures the number of threads needed equals the wave size
-   - Example: `WaveTile = seq<64, 4>` with `Vector = seq<1, 4>` means:
-     - Each thread handles 4 elements (Vector_N = 4)
+   - Example: `WaveTile = seq<64, 4>` with `ThreadTile = seq<1, 4>` means:
+     - Each thread handles 4 elements (ThreadTile_N = 4)
      - Wave needs 64×4/4 = 64 threads to cover 64×4 = 256 elements
      - Total elements = 256, which requires WaveSize = 64 threads
 
@@ -144,8 +144,9 @@ The CK Tile framework organizes work in a hierarchical manner:
    - Example: `BlockTile = seq<256, 64>` means each block processes 256×64 elements
 
 4. **BlockWaves**: Number of concurrent waves active in a block
-   - Usually 4 waves per block on modern AMD GPUs
-   - Example: `BlockWaves = seq<4, 1>` means 4 waves along M dimension, 1 along N
+   - Typical: 4 waves for heavy workloads (e.g., GEMM)
+   - Limit: up to 1024 threads per block → up to 16 waves (CDNA) or 32 waves (RDNA)
+   - Example: `BlockWaves = seq<4, 1>` means 4 waves along M, 1 along N
 
 ### Wave Repetition
 
@@ -159,7 +160,7 @@
 static constexpr index_t WaveRepetitionPerBlock_N =
     Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
 ```
 
-**Key Insight**: When waves repeat, the effective work per thread becomes `Vector * Repeat`, not just `Vector`.
+**Key Insight**: When waves repeat, the effective work per thread becomes `ThreadTile * Repeat`, not just `ThreadTile`.
 
 ## Tile Distribution Encoding
 
@@ -183,8 +184,9 @@ constexpr auto outer_encoding =
   - M2: Number of threads per wave along M
 - **N0, N1**: Distribution along N dimension
   - N0: Number of threads along N
-  - N1: Vector size (elements per thread)
-- **YIELD arguments**: Both `Repeat` and `Vector` because effective work per thread is `Vector * Repeat`
+  - N1: ThreadTile size (elements per thread)
+- **Order and layout**: The inner-most (rightmost) dimension is the fastest-changing. Choosing `N1 = ThreadTile_N` maps vector width to contiguous addresses, i.e., row-major access in this example.
+- **YIELD arguments**: Both `Repeat` and `ThreadTile` because effective work per thread is `ThreadTile * Repeat`
 
 ## Tensor Abstractions
 
@@ -194,7 +196,7 @@ Defines the logical structure of a tensor:
 auto desc = make_naive_tensor_descriptor(
     make_tuple(M, N),       // tensor dimensions
     make_tuple(N, 1),       // strides
-    number<Vector_N>{},     // vector length for vectorized access
+    number<ThreadTile_N>{}, // per-thread vector length
    number<1>{}              // guaranteed last dimension vector stride
 );
 ```
@@ -206,7 +208,7 @@ auto x_m_n = make_naive_tensor_view(
     p_x,                    // memory buffer
     make_tuple(M, N),       // dimensions
     make_tuple(N, 1),       // strides
-    number<Vector_N>{},     // vector length
+    number<ThreadTile_N>{}, // per-thread vector length
     number<1>{}             // guaranteed last dimension vector stride
 );
 ```
@@ -247,10 +249,10 @@ struct TileCopyKernel
 1. **Tensor View Creation**:
    ```cpp
    const auto x_m_n = make_naive_tensor_view(
-       p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+       p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
    ```
    - Creates views for both input and output tensors
-   - Specifies vectorized access with `Vector_N` elements per load
+   - Specifies vectorized access with `ThreadTile_N` elements per load
 
 2. **Tile Window Creation**:
    ```cpp
diff --git a/example/ck_tile/39_copy/copy_basic.cpp b/example/ck_tile/39_copy/copy_basic.cpp
index d46add879c..460036a641 100644
--- a/example/ck_tile/39_copy/copy_basic.cpp
+++ b/example/ck_tile/39_copy/copy_basic.cpp
@@ -54,7 +54,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     x_buf.ToDevice(x_host.data());
 
     // Define tile configuration
-    using Vector     = ck_tile::sequence<1, 4>;   // vector size along M and N dimension
+    using ThreadTile = ck_tile::sequence<1, 4>;   // per-thread tile size along M and N
     using WaveTile   = ck_tile::sequence<64, 4>;  // wave size along M and N dimension
     using BlockWaves = ck_tile::sequence<4, 1>;   // number of waves along M dimension
     using BlockTile  = ck_tile::sequence<512, 4>; // block size along M and N dimension
@@ -65,7 +65,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     std::cout << "grid size (number of blocks per grid) " << kGridSize << std::endl;
 
     // Define kernel types
-    using Shape   = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
+    using Shape   = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
     using Problem = ck_tile::TileCopyProblem;
     using Policy  = ck_tile::TileCopyPolicy;
     using Kernel  = ck_tile::ElementWiseTileCopyKernel;
@@ -88,8 +88,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
               << " " << BlockTile::at(ck_tile::number<1>{}) << std::endl;
     std::cout << "wave tile (number of elements per wave) " << WaveTile::at(ck_tile::number<0>{})
               << " " << WaveTile::at(ck_tile::number<1>{}) << std::endl;
-    std::cout << "vector (number of elements per thread) " << Vector::at(ck_tile::number<0>{})
-              << " " << Vector::at(ck_tile::number<1>{}) << std::endl;
+    std::cout << "thread tile (number of elements per thread) "
+              << ThreadTile::at(ck_tile::number<0>{}) << " " << ThreadTile::at(ck_tile::number<1>{})
+              << std::endl;
     std::cout << "WaveRepetitionPerBlock_M = " << Shape::WaveRepetitionPerBlock_M << " --> ("
               << Shape::Block_Tile_M << "/" << Shape::Waves_Per_Block_M << "*" << Shape::Wave_Tile_M
               << ")" << std::endl;
diff --git a/example/ck_tile/39_copy/copy_basic.hpp b/example/ck_tile/39_copy/copy_basic.hpp
index bbeb964fda..1a313e1353 100644
--- a/example/ck_tile/39_copy/copy_basic.hpp
+++ b/example/ck_tile/39_copy/copy_basic.hpp
@@ -17,14 +17,14 @@ namespace ck_tile {
  * @tparam BlockWaves Number of waves along seq
  * @tparam BlockTile  Block size, seq
  * @tparam WaveTile   Wave size, seq
- * @tparam Vector     Contiguous elements (vector size) along seq
+ * @tparam ThreadTile Contiguous elements per thread along seq
  */
-template <typename BlockWaves, typename BlockTile, typename WaveTile, typename Vector>
+template <typename BlockWaves, typename BlockTile, typename WaveTile, typename ThreadTile>
 struct TileCopyShape
 {
-    // Vector dimensions for memory operations
-    static constexpr index_t Vector_M = Vector::at(number<0>{});
-    static constexpr index_t Vector_N = Vector::at(number<1>{});
+    // ThreadTile dimensions for memory operations
+    static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
+    static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
 
     // Wave tile dimensions
     static constexpr index_t Wave_Tile_M = WaveTile::at(number<0>{});
@@ -51,7 +51,7 @@ struct TileCopyShape
     // Configuration validation
     static_assert(Block_Tile_M > 0 && Block_Tile_N > 0, "Block tile dimensions must be positive");
     static_assert(Wave_Tile_M > 0 && Wave_Tile_N > 0, "Wave tile dimensions must be positive");
-    static_assert(Vector_M > 0 && Vector_N > 0, "Vector dimensions must be positive");
+    static_assert(ThreadTile_M > 0 && ThreadTile_N > 0, "ThreadTile dimensions must be positive");
     static_assert(Waves_Per_Block_M > 0 && Waves_Per_Block_N > 0,
                   "Waves per block must be positive");
     static_assert(Waves_Per_Block_M * Wave_Tile_M > 0,
@@ -60,8 +60,8 @@ struct TileCopyShape
                   "Invalid wave configuration for N dimension");
 
     // Ensure wave tile dimensions align with wave size
-    static_assert(Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize,
-                  "(Wave_Tile_M/Vector_M) * (Wave_Tile_N/Vector_N) != WaveSize");
+    static_assert(Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize,
+                  "(Wave_Tile_M/ThreadTile_M) * (Wave_Tile_N/ThreadTile_N) != WaveSize");
 };
 
 /**
@@ -95,7 +95,7 @@ struct TileCopyPolicy
         constexpr index_t block_size = S::BlockSize;
 
         // Distribution calculation to ensure all threads participate
-        constexpr index_t N1 = S::Vector_N;          // Elements per thread along N
+        constexpr index_t N1 = S::ThreadTile_N;      // Elements per thread along N
         constexpr index_t N0 = S::Block_Tile_N / N1; // Threads needed along N
        constexpr index_t M2 = wave_size / N0;        // Threads per wave along M
 
@@ -143,23 +143,21 @@ struct TileCopyKernel
 
         // Create tensor views for input and output
         const auto x_m_n = make_naive_tensor_view(
-            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
 
         const auto y_m_n = make_naive_tensor_view(
-            p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+            p_y, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
 
         // Create tile windows with DRAM distribution
-        auto x_window =
-            make_tile_window(x_m_n,
-                             make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
-                             {tile_block_origin_m, 0},
-                             Policy::template MakeDRAMDistribution());
+        auto x_window = make_tile_window(x_m_n,
+                                         make_tuple(S::Block_Tile_M, S::Block_Tile_N),
+                                         {tile_block_origin_m, 0},
+                                         Policy::template MakeDRAMDistribution());
 
-        auto y_window =
-            make_tile_window(y_m_n,
-                             make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
-                             {tile_block_origin_m, 0},
-                             Policy::template MakeDRAMDistribution());
+        auto y_window = make_tile_window(y_m_n,
+                                         make_tuple(S::Block_Tile_M, S::Block_Tile_N),
+                                         {tile_block_origin_m, 0},
+                                         Policy::template MakeDRAMDistribution());
 
         // Calculate iterations needed to cover N dimension
         // Note: This kernel uses data parallelism only in the M dimension.
@@ -218,23 +216,21 @@ struct ElementWiseTileCopyKernel
 
         // Create tensor views for input and output
         const auto x_m_n = make_naive_tensor_view(
-            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
 
         const auto y_m_n = make_naive_tensor_view(
-            p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+            p_y, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
 
         // Create tile windows with DRAM distribution
-        auto x_window =
-            make_tile_window(x_m_n,
-                             make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
-                             {tile_block_origin_m, 0},
-                             Policy::template MakeDRAMDistribution());
+        auto x_window = make_tile_window(x_m_n,
+                                         make_tuple(S::Block_Tile_M, S::Block_Tile_N),
+                                         {tile_block_origin_m, 0},
+                                         Policy::template MakeDRAMDistribution());
 
-        auto y_window =
-            make_tile_window(y_m_n,
-                             make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
-                             {tile_block_origin_m, 0},
-                             Policy::template MakeDRAMDistribution());
+        auto y_window = make_tile_window(y_m_n,
+                                         make_tuple(S::Block_Tile_M, S::Block_Tile_N),
+                                         {tile_block_origin_m, 0},
+                                         Policy::template MakeDRAMDistribution());
 
         // Calculate iterations needed to cover N dimension
         // Note: This kernel uses data parallelism only in the M dimension.
@@ -297,45 +293,41 @@ struct TileCopyKernel_LDS
         }
 
         // LDS buffer allocation
-        __shared__ XDataType x_lds_buffer[S::Block_Tile_M * S::Block_Tile_N];
+        __shared__ XDataType x_lds_buffer[S::Block_Tile_M * S::Block_Tile_N];
 
         // LDS tensor descriptor and view
         const auto x_lds_descriptor =
             make_naive_tensor_descriptor(make_tuple(S::Block_Tile_M, S::Block_Tile_N),
                                          make_tuple(S::Block_Tile_N, 1),
-                                         number<S::Vector_N>{},
+                                         number<S::ThreadTile_N>{},
                                          number<1>{});
         auto x_lds_view = make_tensor_view(x_lds_buffer, x_lds_descriptor);
 
         // LDS windows with different distributions for optimal access patterns
-        auto x_lds_write_window = make_tile_window(
-            x_lds_view, make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}), {0, 0});
+        auto x_lds_write_window =
+            make_tile_window(x_lds_view, make_tuple(S::Block_Tile_M, S::Block_Tile_N), {0, 0});
 
-        auto x_lds_read_window =
-            make_tile_window(x_lds_view,
-                             make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
-                             {0, 0},
-                             Policy::template MakeDRAMDistribution());
+        auto x_lds_read_window = make_tile_window(x_lds_view,
+                                                  make_tuple(S::Block_Tile_M, S::Block_Tile_N),
+                                                  {0, 0},
+                                                  Policy::template MakeDRAMDistribution());
 
         // Global memory tensor views
         const auto x_m_n = make_naive_tensor_view(
-            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
 
         const auto y_m_n = make_naive_tensor_view(
-            p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+            p_y, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
 
         // Global memory tile windows
-        auto x_window =
-            make_tile_window(x_m_n,
-                             make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
-                             {tile_block_origin_m, 0},
-                             Policy::template MakeDRAMDistribution());
+        auto x_window = make_tile_window(x_m_n,
+                                         make_tuple(S::Block_Tile_M, S::Block_Tile_N),
+                                         {tile_block_origin_m, 0},
+                                         Policy::template MakeDRAMDistribution());
 
-        auto y_window =
-            make_tile_window(y_m_n,
-                             make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
-                             {tile_block_origin_m, 0});
+        auto y_window = make_tile_window(
+            y_m_n, make_tuple(S::Block_Tile_M, S::Block_Tile_N), {tile_block_origin_m, 0});
 
         // Calculate iterations needed to cover N dimension
         // Note: This kernel uses data parallelism only in the M dimension.
 
diff --git a/example/ck_tile/39_copy/test_tile_example.sh b/example/ck_tile/39_copy/test_tile_example.sh
new file mode 100755
index 0000000000..fcd8c8e991
--- /dev/null
+++ b/example/ck_tile/39_copy/test_tile_example.sh
@@ -0,0 +1,51 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+BIN="${BIN:-../../../build/bin/tile_example_copy}"
+WARMUP="${WARMUP:-20}"
+REPEAT="${REPEAT:-100}"
+VALIDATE="${VALIDATE:-1}"
+
+MS=(128 256 512 1024)
+NS=(64 256 1024 2048 4096)
+PRECS=(fp16 fp32)
+
+echo "Using BIN=$BIN"
+echo "WARMUP=$WARMUP REPEAT=$REPEAT VALIDATE=$VALIDATE"
+
+failures=0
+
+for prec in "${PRECS[@]}"; do
+    for m in "${MS[@]}"; do
+        for n in "${NS[@]}"; do
+            echo "=============================================="
+            echo "Running: prec=$prec m=$m n=$n"
+            set +e
+            out="$("$BIN" -prec="$prec" -m="$m" -n="$n" -warmup="$WARMUP" -repeat="$REPEAT" -v="$VALIDATE" 2>&1)"
+            rc=$?
+            set -e
+
+            echo "$out"
+            if [[ $rc -ne 0 ]]; then
+                echo "RUN ERROR (rc=$rc) for m=$m n=$n prec=$prec"
+                ((failures++)) || true
+                continue
+            fi
+
+            if [[ "$VALIDATE" == "1" ]]; then
+                if ! grep -q "valid:y" <<<"$out"; then
+                    echo "VALIDATION FAILED for m=$m n=$n prec=$prec"
+                    ((failures++)) || true
+                fi
+            fi
+        done
+    done
+done
+
+echo "=============================================="
+if [[ $failures -eq 0 ]]; then
+    echo "All runs passed"
+else
+    echo "$failures runs failed"
+    exit 1
+fi
\ No newline at end of file
diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp
index 69f645b850..16fde15c7b 100644
--- a/include/ck_tile/ops/fmha.hpp
+++ b/include/ck_tile/ops/fmha.hpp
@@ -45,6 +45,8 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
@@ -52,8 +54,6 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp
index c9bedd7c53..e792820466 100644
--- a/include/ck_tile/ops/gemm.hpp
+++ b/include/ck_tile/ops/gemm.hpp
@@ -8,6 +8,8 @@
 #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"