Minor Improvements in CK TILE memory copy EXAMPLE (#2678)

* Rename vector to ThreadTile

* more notes on tile encoding

* remove number<> from tuple of make_tile_window

* add script to stress test the copy example
This commit is contained in:
Aviral Goel
2025-08-13 18:24:16 -04:00
committed by GitHub
parent bcc38deff7
commit 8a698c7445
6 changed files with 126 additions and 79 deletions

View File

@@ -38,14 +38,14 @@ The CK Tile framework is built around four key architectural components that wor
Defines the **hierarchical tile structure** and **memory layout** of the kernel:
```cpp
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
```
**Components:**
- **BlockWaves**: Number of concurrent waves per block (e.g., `seq<4, 1>` for 4 waves along M, 1 along N)
- **BlockTile**: Total elements processed by one block (e.g., `seq<512, 8>`)
- **WaveTile**: Elements processed by one wave (e.g., `seq<32, 8>`)
- **Vector**: Elements processed by one thread (e.g., `seq<1, 4>` for 4 contiguous elements)
- **ThreadTile**: Elements processed by one thread (e.g., `seq<1, 4>` for 4 contiguous elements)
**Purpose**: Defines the **work distribution hierarchy** from threads → waves → blocks.
@@ -91,7 +91,7 @@ Defines the **execution flow** and **memory movement patterns**:
```cpp
// Complete kernel definition
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
using Policy = ck_tile::TileCopyPolicy<Problem>;
using Kernel = ck_tile::TileCopyKernel<Problem, Policy>;
@@ -113,7 +113,7 @@ using Kernel = ck_tile::TileCopyKernel<Problem, Policy>;
#### **Reusability**
- Same **Shape** can be used with different **Problems**
- Same **Policy** can be applied to different **Shapes**
- Same **Policy** can be applied to different **Problems**
- **Pipelines** can be reused across different kernels
#### **Performance Optimization**
@@ -127,16 +127,16 @@ using Kernel = ck_tile::TileCopyKernel<Problem, Policy>;
The CK Tile framework organizes work in a hierarchical manner:
1. **Vector**: Number of contiguous elements processed by a single thread
1. **ThreadTile**: Number of contiguous elements processed by a single thread
- Enables vectorized memory loads/stores.
- Example: `Vector = seq<1, 4>` means each thread loads 4 contiguous elements along the N dimension
- A Vector can be imagined as a thread-level tile
- Example: `ThreadTile = seq<1, 4>` means each thread loads 4 contiguous elements along the N dimension
- A ThreadTile can be imagined as a thread-level tile
2. **WaveTile**: Number of elements covered by a single wave (64 threads on AMD)
- Must satisfy: `Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize`
2. **WaveTile**: Number of elements covered by a single wave (64 threads on CDNA, 32 threads on RDNA)
- Must satisfy: `Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize`
- This ensures the number of threads needed equals the wave size
- Example: `WaveTile = seq<64, 4>` with `Vector = seq<1, 4>` means:
- Each thread handles 4 elements (Vector_N = 4)
- Example: `WaveTile = seq<64, 4>` with `ThreadTile = seq<1, 4>` means:
- Each thread handles 4 elements (ThreadTile_N = 4)
- Wave needs 64×4/4 = 64 threads to cover 64×4 = 256 elements
- Total elements = 256, which requires WaveSize = 64 threads
@@ -144,8 +144,9 @@ The CK Tile framework organizes work in a hierarchical manner:
- Example: `BlockTile = seq<256, 64>` means each block processes 256×64 elements
4. **BlockWaves**: Number of concurrent waves active in a block
- Usually 4 waves per block on modern AMD GPUs
- Example: `BlockWaves = seq<4, 1>` means 4 waves along M dimension, 1 along N
- Typical: 4 waves for heavy workloads (e.g., GEMM)
- Limit: up to 1024 threads per block → up to 16 waves (CDNA) or 32 waves (RDNA)
- Example: `BlockWaves = seq<4, 1>` means 4 waves along M, 1 along N
### Wave Repetition
@@ -159,7 +160,7 @@ static constexpr index_t WaveRepetitionPerBlock_N =
Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
```
**Key Insight**: When waves repeat, the effective work per thread becomes `Vector * Repeat`, not just `Vector`.
**Key Insight**: When waves repeat, the effective work per thread becomes `ThreadTile * Repeat`, not just `ThreadTile`.
## Tile Distribution Encoding
@@ -183,8 +184,9 @@ constexpr auto outer_encoding =
- M2: Number of threads per wave along M
- **N0, N1**: Distribution along N dimension
- N0: Number of threads along N
- N1: Vector size (elements per thread)
- **YIELD arguments**: Both `Repeat` and `Vector` because effective work per thread is `Vector * Repeat`
- N1: ThreadTile size (elements per thread)
- **Order and layout**: The inner-most (rightmost) dimension is the fastest-changing. Choosing `N1 = ThreadTile_N` maps vector width to contiguous addresses, i.e., row-major access in this example.
- **YIELD arguments**: Both `Repeat` and `ThreadTile` because effective work per thread is `ThreadTile * Repeat`
## Tensor Abstractions
@@ -194,7 +196,7 @@ Defines the logical structure of a tensor:
auto desc = make_naive_tensor_descriptor(
make_tuple(M, N), // tensor dimensions
make_tuple(N, 1), // strides
number<Vector_N>{}, // vector length for vectorized access
number<ThreadTile_N>{}, // per-thread vector length
number<1>{} // guaranteed last dimension vector stride
);
```
@@ -206,7 +208,7 @@ auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, // memory buffer
make_tuple(M, N), // dimensions
make_tuple(N, 1), // strides
number<S::Vector_N>{}, // vector length
number<S::ThreadTile_N>{}, // per-thread vector length
number<1>{} // guaranteed last dimension vector stride
);
```
@@ -247,10 +249,10 @@ struct TileCopyKernel
1. **Tensor View Creation**:
```cpp
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
```
- Creates views for both input and output tensors
- Specifies vectorized access with `Vector_N` elements per load
- Specifies vectorized access with `ThreadTile_N` elements per load
2. **Tile Window Creation**:
```cpp

View File

@@ -54,7 +54,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
x_buf.ToDevice(x_host.data());
// Define tile configuration
using Vector = ck_tile::sequence<1, 4>; // vector size along M and N dimension
using ThreadTile = ck_tile::sequence<1, 4>; // per-thread tile size along M and N
using WaveTile = ck_tile::sequence<64, 4>; // wave size along M and N dimension
using BlockWaves = ck_tile::sequence<4, 1>; // number of waves along M dimension
using BlockTile = ck_tile::sequence<512, 4>; // block size along M and N dimension
@@ -65,7 +65,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "grid size (number of blocks per grid) " << kGridSize << std::endl;
// Define kernel types
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
using Policy = ck_tile::TileCopyPolicy<Problem>;
using Kernel = ck_tile::ElementWiseTileCopyKernel<Problem, Policy>;
@@ -88,8 +88,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< " " << BlockTile::at(ck_tile::number<1>{}) << std::endl;
std::cout << "wave tile (number of elements per wave) " << WaveTile::at(ck_tile::number<0>{})
<< " " << WaveTile::at(ck_tile::number<1>{}) << std::endl;
std::cout << "vector (number of elements per thread) " << Vector::at(ck_tile::number<0>{})
<< " " << Vector::at(ck_tile::number<1>{}) << std::endl;
std::cout << "thread tile (number of elements per thread) "
<< ThreadTile::at(ck_tile::number<0>{}) << " " << ThreadTile::at(ck_tile::number<1>{})
<< std::endl;
std::cout << "WaveRepetitionPerBlock_M = " << Shape::WaveRepetitionPerBlock_M << " --> ("
<< Shape::Block_Tile_M << "/" << Shape::Waves_Per_Block_M << "*" << Shape::Wave_Tile_M
<< ")" << std::endl;

View File

@@ -17,14 +17,14 @@ namespace ck_tile {
* @tparam BlockWaves Number of waves along seq<M, N>
* @tparam BlockTile Block size, seq<M, N>
* @tparam WaveTile Wave size, seq<M, N>
* @tparam Vector Contiguous elements (vector size) along seq<M, N>
* @tparam ThreadTile Contiguous elements per thread along seq<M, N>
*/
template <typename BlockWaves, typename BlockTile, typename WaveTile, typename Vector>
template <typename BlockWaves, typename BlockTile, typename WaveTile, typename ThreadTile>
struct TileCopyShape
{
// Vector dimensions for memory operations
static constexpr index_t Vector_M = Vector::at(number<0>{});
static constexpr index_t Vector_N = Vector::at(number<1>{});
// ThreadTile dimensions for memory operations
static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
// Wave tile dimensions
static constexpr index_t Wave_Tile_M = WaveTile::at(number<0>{});
@@ -51,7 +51,7 @@ struct TileCopyShape
// Configuration validation
static_assert(Block_Tile_M > 0 && Block_Tile_N > 0, "Block tile dimensions must be positive");
static_assert(Wave_Tile_M > 0 && Wave_Tile_N > 0, "Wave tile dimensions must be positive");
static_assert(Vector_M > 0 && Vector_N > 0, "Vector dimensions must be positive");
static_assert(ThreadTile_M > 0 && ThreadTile_N > 0, "ThreadTile dimensions must be positive");
static_assert(Waves_Per_Block_M > 0 && Waves_Per_Block_N > 0,
"Waves per block must be positive");
static_assert(Waves_Per_Block_M * Wave_Tile_M > 0,
@@ -60,8 +60,8 @@ struct TileCopyShape
"Invalid wave configuration for N dimension");
// Ensure wave tile dimensions align with wave size
static_assert(Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize,
"(Wave_Tile_M/Vector_M) * (Wave_Tile_N/Vector_N) != WaveSize");
static_assert(Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize,
"(Wave_Tile_M/ThreadTile_M) * (Wave_Tile_N/ThreadTile_N) != WaveSize");
};
/**
@@ -95,7 +95,7 @@ struct TileCopyPolicy
constexpr index_t block_size = S::BlockSize;
// Distribution calculation to ensure all threads participate
constexpr index_t N1 = S::Vector_N; // Elements per thread along N
constexpr index_t N1 = S::ThreadTile_N; // Elements per thread along N
constexpr index_t N0 = S::Block_Tile_N / N1; // Threads needed along N
constexpr index_t M2 = wave_size / N0; // Threads per wave along M
@@ -143,23 +143,21 @@ struct TileCopyKernel
// Create tensor views for input and output
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
// Create tile windows with DRAM distribution
auto x_window =
make_tile_window(x_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto x_window = make_tile_window(x_m_n,
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto y_window =
make_tile_window(y_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto y_window = make_tile_window(y_m_n,
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
// Calculate iterations needed to cover N dimension
// Note: This kernel uses data parallelism only in the M dimension.
@@ -218,23 +216,21 @@ struct ElementWiseTileCopyKernel
// Create tensor views for input and output
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
// Create tile windows with DRAM distribution
auto x_window =
make_tile_window(x_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto x_window = make_tile_window(x_m_n,
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto y_window =
make_tile_window(y_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto y_window = make_tile_window(y_m_n,
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
// Calculate iterations needed to cover N dimension
// Note: This kernel uses data parallelism only in the M dimension.
@@ -297,45 +293,41 @@ struct TileCopyKernel_LDS
}
// LDS buffer allocation
__shared__ XDataType x_lds_buffer[S::Block_Tile_M * S::Block_Tile_N];
__shared__ XDataType x_lds_buffer[S::Block_Tile_Mmake * S::Block_Tile_N];
// LDS tensor descriptor and view
const auto x_lds_descriptor =
make_naive_tensor_descriptor(make_tuple(S::Block_Tile_M, S::Block_Tile_N),
make_tuple(S::Block_Tile_N, 1),
number<S::Vector_N>{},
number<S::ThreadTile_N>{},
number<1>{});
auto x_lds_view = make_tensor_view<address_space_enum::lds>(x_lds_buffer, x_lds_descriptor);
// LDS windows with different distributions for optimal access patterns
auto x_lds_write_window = make_tile_window(
x_lds_view, make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}), {0, 0});
auto x_lds_write_window =
make_tile_window(x_lds_view, make_tuple(S::Block_Tile_M, S::Block_Tile_N), {0, 0});
auto x_lds_read_window =
make_tile_window(x_lds_view,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{0, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto x_lds_read_window = make_tile_window(x_lds_view,
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
{0, 0},
Policy::template MakeDRAMDistribution<Problem>());
// Global memory tensor views
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
// Global memory tile windows
auto x_window =
make_tile_window(x_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto x_window = make_tile_window(x_m_n,
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto y_window =
make_tile_window(y_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0});
auto y_window = make_tile_window(
y_m_n, make_tuple(S::Block_Tile_M, S::Block_Tile_N), {tile_block_origin_m, 0});
// Calculate iterations needed to cover N dimension
// Note: This kernel uses data parallelism only in the M dimension.

View File

@@ -0,0 +1,50 @@
#!/usr/bin/env bash
set -euo pipefail
BIN="${BIN:-../../../build/bin/tile_example_copy}"
WARMUP="${WARMUP:-20}"
REPEAT="${REPEAT:-100}"
VALIDATE="${VALIDATE:-1}"
MS=(128 256 512 1024)
NS=(64 256 1024 2048 4096)
PRECS=(fp16 fp32)
echo "Using BIN=$BIN"
echo "WARMUP=$WARMUP REPEAT=$REPEAT VALIDATE=$VALIDATE"
failures=0
for prec in "${PRECS[@]}"; do
for m in "${MS[@]}"; do
for n in "${NS[@]}"; do
echo "=============================================="
echo "Running: prec=$prec m=$m n=$n"
set +e
out="$("$BIN" -prec="$prec" -m="$m" -n="$n" -warmup="$WARMUP" -repeat="$REPEAT" -v="$VALIDATE" 2>&1)"
rc=$?
set -e
echo "$out"
if [[ $rc -ne 0 ]]; then
echo "RUN ERROR (rc=$rc) for m=$m n=$n prec=$prec"
((failures++)) || true
continue
fi
if [[ "$VALIDATE" == "1" ]]; then
if ! grep -q "valid:y" <<<"$out"; then
echo "VALIDATION FAILED for m=$m n=$n prec=$prec"
((failures++)) || true
fi
fi
done
done
done
echo "=============================================="
if [[ $failures -eq 0 ]]; then
echo "All runs passed"
else
echo "$failures runs failed"
fi