Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-02 04:31:25 +00:00)
Minor Improvements in CK TILE memory copy EXAMPLE (#2678)
* Rename vector to ThreadTile
* more notes on tile encoding
* remove number<> from tuple of make_tile_window
* add script to stress test the copy example
@@ -38,14 +38,14 @@ The CK Tile framework is built around four key architectural components that wor
 Defines the **hierarchical tile structure** and **memory layout** of the kernel:

 ```cpp
-using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
+using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
 ```

 **Components:**
 - **BlockWaves**: Number of concurrent waves per block (e.g., `seq<4, 1>` for 4 waves along M, 1 along N)
 - **BlockTile**: Total elements processed by one block (e.g., `seq<512, 8>`)
 - **WaveTile**: Elements processed by one wave (e.g., `seq<32, 8>`)
-- **Vector**: Elements processed by one thread (e.g., `seq<1, 4>` for 4 contiguous elements)
+- **ThreadTile**: Elements processed by one thread (e.g., `seq<1, 4>` for 4 contiguous elements)

 **Purpose**: Defines the **work distribution hierarchy** from threads → waves → blocks.
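For concreteness, a minimal sketch assembling a `Shape` from the example values above; the `ck_tile::sequence` spelling of `seq<...>` and the include path are assumptions, not taken from this diff:

```cpp
#include "ck_tile/core.hpp"  // assumed include path; TileCopyShape comes from the copy example headers

// Illustrative Shape built from the example values quoted above.
using BlockWaves = ck_tile::sequence<4, 1>;    // 4 waves along M, 1 along N
using BlockTile  = ck_tile::sequence<512, 8>;  // elements per block
using WaveTile   = ck_tile::sequence<32, 8>;   // elements per wave
using ThreadTile = ck_tile::sequence<1, 4>;    // contiguous elements per thread
// (32 / 1) * (8 / 4) == 64 threads, i.e. exactly one 64-wide CDNA wave per WaveTile.
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
```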
@@ -91,7 +91,7 @@ Defines the **execution flow** and **memory movement patterns**:
 ```cpp
 // Complete kernel definition
-using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
+using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
 using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
 using Policy = ck_tile::TileCopyPolicy<Problem>;
 using Kernel = ck_tile::TileCopyKernel<Problem, Policy>;
@@ -113,7 +113,7 @@ using Kernel = ck_tile::TileCopyKernel<Problem, Policy>;
 #### **Reusability**
 - Same **Shape** can be used with different **Problems**
-- Same **Policy** can be applied to different **Shapes**
+- Same **Policy** can be applied to different **Problems**
 - **Pipelines** can be reused across different kernels

 #### **Performance Optimization**
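A hedged sketch of the reuse claims above: the same `Shape` with the `TileCopyProblem`/`TileCopyPolicy`/`TileCopyKernel` templates instantiated for two element types; `ck_tile::half_t` and `float` are placeholder choices, and `Shape` is the alias from the earlier sketch:

```cpp
// Reuse sketch, assuming the Shape alias defined earlier in this document.
using ProblemF16 = ck_tile::TileCopyProblem<ck_tile::half_t, Shape>;
using ProblemF32 = ck_tile::TileCopyProblem<float, Shape>;

// The same policy and kernel templates compose with either problem.
using KernelF16 = ck_tile::TileCopyKernel<ProblemF16, ck_tile::TileCopyPolicy<ProblemF16>>;
using KernelF32 = ck_tile::TileCopyKernel<ProblemF32, ck_tile::TileCopyPolicy<ProblemF32>>;
```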
@@ -127,16 +127,16 @@ using Kernel = ck_tile::TileCopyKernel<Problem, Policy>;
 The CK Tile framework organizes work in a hierarchical manner:

-1. **Vector**: Number of contiguous elements processed by a single thread
+1. **ThreadTile**: Number of contiguous elements processed by a single thread
    - Enables vectorized memory loads/stores.
-   - Example: `Vector = seq<1, 4>` means each thread loads 4 contiguous elements along the N dimension
-   - A Vector can be imagined as a thread-level tile
+   - Example: `ThreadTile = seq<1, 4>` means each thread loads 4 contiguous elements along the N dimension
+   - A ThreadTile can be imagined as a thread-level tile

-2. **WaveTile**: Number of elements covered by a single wave (64 threads on AMD)
-   - Must satisfy: `Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize`
+2. **WaveTile**: Number of elements covered by a single wave (64 threads on CDNA, 32 threads on RDNA)
+   - Must satisfy: `Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize`
    - This ensures the number of threads needed equals the wave size
-   - Example: `WaveTile = seq<64, 4>` with `Vector = seq<1, 4>` means:
-     - Each thread handles 4 elements (Vector_N = 4)
+   - Example: `WaveTile = seq<64, 4>` with `ThreadTile = seq<1, 4>` means:
+     - Each thread handles 4 elements (ThreadTile_N = 4)
      - Wave needs 64×4/4 = 64 threads to cover 64×4 = 256 elements
      - Total elements = 256, which requires WaveSize = 64 threads
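The wave-size constraint above can be checked at compile time. A plain C++ sketch using the sample numbers from the hunk (`WaveTile = seq<64, 4>`, `ThreadTile = seq<1, 4>`), purely illustrative:

```cpp
// Compile-time check of: WaveTile_M / ThreadTile_M * WaveTile_N / ThreadTile_N == WaveSize
constexpr int WaveSize     = 64;               // CDNA wavefront width (32 on RDNA)
constexpr int WaveTile_M   = 64, WaveTile_N   = 4;
constexpr int ThreadTile_M = 1,  ThreadTile_N = 4;
static_assert((WaveTile_M / ThreadTile_M) * (WaveTile_N / ThreadTile_N) == WaveSize,
              "one wave of threads must exactly cover one WaveTile");
```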
@@ -144,8 +144,9 @@ The CK Tile framework organizes work in a hierarchical manner:
    - Example: `BlockTile = seq<256, 64>` means each block processes 256×64 elements

 4. **BlockWaves**: Number of concurrent waves active in a block
-   - Usually 4 waves per block on modern AMD GPUs
-   - Example: `BlockWaves = seq<4, 1>` means 4 waves along M dimension, 1 along N
+   - Typical: 4 waves for heavy workloads (e.g., GEMM)
+   - Limit: up to 1024 threads per block → up to 16 waves (CDNA) or 32 waves (RDNA)
+   - Example: `BlockWaves = seq<4, 1>` means 4 waves along M, 1 along N

 ### Wave Repetition
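The wave-count limit quoted above follows directly from the 1024-thread block limit; a one-line arithmetic check:

```cpp
// 1024 threads per block divided by the wavefront width gives the wave limit.
constexpr int MaxThreadsPerBlock = 1024;
static_assert(MaxThreadsPerBlock / 64 == 16, "up to 16 waves per block on CDNA (wave64)");
static_assert(MaxThreadsPerBlock / 32 == 32, "up to 32 waves per block on RDNA (wave32)");
```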
@@ -159,7 +160,7 @@ static constexpr index_t WaveRepetitionPerBlock_N =
     Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
 ```

-**Key Insight**: When waves repeat, the effective work per thread becomes `Vector * Repeat`, not just `Vector`.
+**Key Insight**: When waves repeat, the effective work per thread becomes `ThreadTile * Repeat`, not just `ThreadTile`.

 ## Tile Distribution Encoding
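A worked example of the repetition formula, using the sample shape quoted at the start of the document (`BlockTile = seq<512, 8>`, `BlockWaves = seq<4, 1>`, `WaveTile = seq<32, 8>`); plain C++, illustrative numbers only:

```cpp
constexpr int BlockTile_M = 512, BlockTile_N = 8;
constexpr int BlockWaves_M = 4,  BlockWaves_N = 1;
constexpr int WaveTile_M  = 32,  WaveTile_N  = 8;

// WaveRepetitionPerBlock = BlockTile / (BlockWaves * WaveTile), per dimension.
constexpr int Repeat_M = BlockTile_M / (BlockWaves_M * WaveTile_M); // 512 / 128 = 4
constexpr int Repeat_N = BlockTile_N / (BlockWaves_N * WaveTile_N); // 8 / 8 = 1
static_assert(Repeat_M == 4 && Repeat_N == 1,
              "each wave sweeps its WaveTile 4 times along M, once along N");
// With ThreadTile = seq<1, 4>, each thread therefore touches 4 * 4 = 16 elements in total.
```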
@@ -183,8 +184,9 @@ constexpr auto outer_encoding =
   - M2: Number of threads per wave along M
 - **N0, N1**: Distribution along N dimension
   - N0: Number of threads along N
-  - N1: Vector size (elements per thread)
-- **YIELD arguments**: Both `Repeat` and `Vector` because effective work per thread is `Vector * Repeat`
+  - N1: ThreadTile size (elements per thread)
+- **Order and layout**: The inner-most (rightmost) dimension is the fastest-changing. Choosing `N1 = ThreadTile_N` maps vector width to contiguous addresses, i.e., row-major access in this example.
+- **YIELD arguments**: Both `Repeat` and `ThreadTile` because effective work per thread is `ThreadTile * Repeat`

 ## Tensor Abstractions
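To make the encoding dimensions concrete, here is how the sample block tile factorizes; the mapping of `M0`/`M1`/`M2` and `N0`/`N1` onto repeat, waves, and threads follows the descriptions above but is otherwise an assumption:

```cpp
// BlockTile_M = 512 splits into repeat x waves x threads-per-wave along M;
// BlockTile_N = 8 splits into threads x ThreadTile_N along N (sample numbers only).
constexpr int M0 = 4;   // assumed: repeat along M
constexpr int M1 = 4;   // assumed: waves per block along M
constexpr int M2 = 32;  // threads per wave along M  (WaveTile_M / ThreadTile_M = 32 / 1)
constexpr int N0 = 2;   // threads along N           (WaveTile_N / ThreadTile_N = 8 / 4)
constexpr int N1 = 4;   // ThreadTile_N (contiguous elements per thread)
static_assert(M0 * M1 * M2 == 512 && N0 * N1 == 8, "factors rebuild the 512x8 block tile");
```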
@@ -194,7 +196,7 @@ Defines the logical structure of a tensor:
 auto desc = make_naive_tensor_descriptor(
     make_tuple(M, N),        // tensor dimensions
     make_tuple(N, 1),        // strides
-    number<Vector_N>{},      // vector length for vectorized access
+    number<ThreadTile_N>{},  // per-thread vector length
     number<1>{}              // guaranteed last dimension vector stride
 );
 ```
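The stride tuple `make_tuple(N, 1)` above implies row-major addressing, which is what makes a per-thread vector of `ThreadTile_N` elements contiguous. A plain C++ sketch with made-up sizes:

```cpp
// Element (m, n) of a row-major tensor with strides (N, 1): offset = m * N + n.
constexpr int N = 16, ThreadTile_N = 4;
constexpr int offset(int m, int n) { return m * N + n; }

static_assert(offset(2, 5) == offset(2, 4) + 1, "adjacent n values are adjacent in memory");
static_assert(offset(2, 4 + ThreadTile_N) == offset(2, 4) + ThreadTile_N,
              "a ThreadTile_N-wide per-thread load covers contiguous addresses");
```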
@@ -206,7 +208,7 @@ auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
     p_x,                          // memory buffer
     make_tuple(M, N),             // dimensions
     make_tuple(N, 1),             // strides
-    number<S::Vector_N>{},        // vector length
+    number<S::ThreadTile_N>{},    // per-thread vector length
     number<1>{}                   // guaranteed last dimension vector stride
 );
 ```
@@ -247,10 +249,10 @@ struct TileCopyKernel
 1. **Tensor View Creation**:
    ```cpp
    const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
-       p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
+       p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
    ```
    - Creates views for both input and output tensors
-   - Specifies vectorized access with `Vector_N` elements per load
+   - Specifies vectorized access with `ThreadTile_N` elements per load

 2. **Tile Window Creation**:
    ```cpp