docs: add notes on tile distribution and inline comments (#3297)

* docs: add notes on tile distribution and inline comments

* Apply suggestions from code review

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

---------

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
Aviral Goel
2025-12-11 10:47:19 +04:00
committed by GitHub
parent 8270900d60
commit fbbdd36ea8
5 changed files with 347 additions and 20 deletions


@@ -0,0 +1,312 @@
# Tile Distribution: Mapping Threads to Data
## Overview
**Tile Distribution** describes how each thread in a thread block maps to elements of a block tile. It defines the hierarchical pattern of data distribution across threads, warps, and thread blocks.
## The Problem
Given a block tile of size `MPerBlock × KPerBlock` (e.g., 256×32), we need to determine:
- Which threads load which elements.
- How the threads are organized into warps.
- The number of times each warp repeats its pattern.
- The number of elements each thread can load in a single vector instruction.
---
## Bottom-Up Construction Approach
### Step 1: Determine K Dimension Layout
**Start with the innermost dimension (K) for memory coalescing:**
```cpp
constexpr index_t K1 = 16 / sizeof(ADataType); // Elements per thread (vector load)
constexpr index_t K0 = kKPerBlock / K1; // Threads needed in K dimension
```
**Example (with fp16):**
- `K1 = 16 / 2 = 8` → Each thread loads 8 fp16 elements with one 16-byte vector instruction
- `kKPerBlock = 32`
- `K0 = 32 / 8 = 4` → We need 4 threads along K to cover the entire K dimension
**Visual:**
```
K dimension (32 elements):
Thread 0: [0-7] Thread 1: [8-15] Thread 2: [16-23] Thread 3: [24-31]
K1=8 K1=8 K1=8 K1=8
├──────────────────────────────────────────────────────────────┤
K0=4 threads
```
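As a quick sanity check, the two constants above can be verified at compile time. A standalone sketch, where the 2-byte `sizeof_ADataType` stands in for `sizeof(ck_tile::half_t)`:
```cpp
constexpr int kKPerBlock       = 32;
constexpr int sizeof_ADataType = 2;           // fp16
constexpr int K1 = 16 / sizeof_ADataType;     // elements per 16-byte vector load
constexpr int K0 = kKPerBlock / K1;           // threads along K
static_assert(K1 == 8 && K0 == 4, "matches the worked example above");
```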
---
### Step 2: Determine M Dimension Layout
**Now partition the M dimension hierarchically:**
#### Level 1: Threads per Warp in M (M2)
```cpp
constexpr index_t M2 = get_warp_size() / K0;
```
- Warp size = 64 threads
- K dimension already uses `K0 = 4` threads per row
- `M2 = 64 / 4 = 16` → Each warp has 16 threads along the M dimension
**Visual (Single Warp):**
```
K dimension (4 threads)
┌─────┬─────┬─────┬─────┐
0 │ T0 │ T1 │ T2 │ T3 │
1 │ T4 │ T5 │ T6 │ T7 │
2 │ T8 │ T9 │ T10 │ T11 │
M 3 │ T12 │ T13 │ T14 │ T15 │ ← 16 rows
...│ ... │ ... │ ... │ ... │ (M2=16)
15 │ T60 │ T61 │ T62 │ T63 │
└─────┴─────┴─────┴─────┘
One Warp = 64 threads
```
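A lane's position inside this warp tile follows directly from its id. The sketch below reproduces the diagram (illustrative helper names, with K0 = 4 as derived above):
```cpp
constexpr int warp_row(int lane) { return lane / 4; } // M2 index (row in warp)
constexpr int warp_col(int lane) { return lane % 4; } // K0 index (column group)
static_assert(warp_row(10) == 2 && warp_col(10) == 2, "T10 sits at row 2, column 2");
```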
#### Level 2: Warps per Block (M1)
```cpp
constexpr index_t M1 = kBlockSize / get_warp_size();
```
- `kBlockSize = 256` threads per block
- `M1 = 256 / 64 = 4` → We have 4 warps per block
**Visual (4 Warps):**
```
Warp 0 (rows 0-15)
Warp 1 (rows 16-31)
Warp 2 (rows 32-47)
Warp 3 (rows 48-63)
M1 = 4 warps cover 64 rows total
```
#### Level 3: Repetitions (M0)
```cpp
constexpr index_t M0 = kMPerBlock / (M2 * M1);
```
- `kMPerBlock = 256` rows to cover
- `M2 * M1 = 16 * 4 = 64` rows covered by all warps
- `M0 = 256 / 64 = 4` → Each warp must repeat its pattern 4 times
**Visual (Complete Block):**
```
┌──────────────┐
│ Iteration 0 │ ← Warp 0: rows 0-15, Warp 1: rows 16-31, ...
│ (rows 0-63) │
├──────────────┤
│ Iteration 1 │ ← Warp 0: rows 64-79, Warp 1: rows 80-95, ...
│ (rows 64-127)│
├──────────────┤
│ Iteration 2 │ ← Warp 0: rows 128-143, Warp 1: rows 144-159, ...
│(rows 128-191)│
├──────────────┤
│ Iteration 3 │ ← Warp 0: rows 192-207, Warp 1: rows 208-223, ...
│(rows 192-255)│
└──────────────┘
M0 = 4 iterations
```
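The three M levels compose into a single global row index. A minimal sketch with this example's strides (`global_row` is an illustrative name, not CK API):
```cpp
// row = iteration stride + warp stride + row within the warp
constexpr int global_row(int m0_iter, int m1_warp, int m2_row)
{
    return m0_iter * 64 + m1_warp * 16 + m2_row; // strides: M1*M2 = 64, M2 = 16
}
static_assert(global_row(2, 1, 3) == 147, "iteration 2, warp 1 covers rows 144-159");
```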
---
## The Tile Distribution Encoding
Now we can construct the distribution:
```cpp
tile_distribution_encoding<
sequence<1>, // [1] Replication
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // [2] Hierarchy
tuple<sequence<1>, sequence<1, 2>>, // [3] Parallelism (major)
tuple<sequence<1>, sequence<2, 0>>, // [3] Parallelism (minor)
sequence<1, 2>, // [4] Yield (major)
sequence<0, 1> // [4] Yield (minor)
>
```
### [1] Replication: `sequence<1>`
Defines how many times warp patterns are replicated:
- `1` = Each warp has a unique pattern (no replication)
- `2` = Warp 0 and Warp 1 do the same thing, Warp 2 and Warp 3 do the same thing
- `4` = All warps do the same thing
In our case: `1` means no replication (each warp is independent).
---
### [2] Hierarchy: The Multi-Level Structure
```cpp
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>
M dimension K dimension
```
**Concrete values:**
- M hierarchy: `sequence<4, 4, 16>` = (4 repetitions, 4 warps, 16 threads/warp)
- K hierarchy: `sequence<4, 8>` = (4 threads, 8 elements/thread)
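Substituting these values gives the fully numeric encoding (the same template as above, with the constants folded in):
```cpp
tile_distribution_encoding<
    sequence<1>,                                 // [1] no replication
    tuple<sequence<4, 4, 16>, sequence<4, 8>>,   // [2] <M0, M1, M2>, <K0, K1>
    tuple<sequence<1>, sequence<1, 2>>,          // [3] parallelism (major)
    tuple<sequence<1>, sequence<2, 0>>,          // [3] parallelism (minor)
    sequence<1, 2>,                              // [4] yield (major)
    sequence<0, 1>>                              // [4] yield (minor)
```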
---
### [3] Parallelism: Addressing the Hierarchy
**The key insight:** Read the tuples **vertically** to understand indexing!
```cpp
tuple<sequence<1>, sequence<1, 2>>
tuple<sequence<1>, sequence<2, 0>>
```
#### Reading Pattern
**Column 1 (P dimension 0 = warp index):**
```
sequence<1>
sequence<1>   → read column-wise as [1, 1] → hierarchy[1][1] = M1 (warps per block in M)
```
**Column 2 (P dimension 1 = lane index):**
```
sequence<1, 2>
sequence<2, 0>
```
- `[1, 2]` → M2 = threads along M per warp
- `[2, 0]` → K0 = threads along K per warp
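One way to sanity-check the vertical reading is a tiny constexpr table mirroring the hierarchy tuple (a sketch with this example's numbers; major index 0 is the replication sequence, unused slots are padded with 0):
```cpp
constexpr int hierarchy[3][3] = {
    {1, 0, 0},   // 0: replication  sequence<1>
    {4, 4, 16},  // 1: M levels     <M0=4, M1=4, M2=16>
    {4, 8, 0},   // 2: K levels     <K0=4, K1=8>
};
static_assert(hierarchy[1][1] == 4,  "[1,1] -> M1: warps per block");
static_assert(hierarchy[1][2] == 16, "[1,2] -> M2: threads along M per warp");
static_assert(hierarchy[2][0] == 4,  "[2,0] -> K0: threads along K per warp");
```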
---
### [4] Yield Sequences: Output Ordering
```cpp
sequence<1, 2>
sequence<0, 1>
```
Read column-wise, just like the parallelism tuples:
- `[1, 0]` → M0 = repetitions per warp in the M dimension
- `[2, 1]` → K1 = elements per thread in the K dimension
---
## Complete Example: Thread 25 in Warp 0
Let's trace where **Thread 25** in **Warp 0** reads data:
### Thread Coordinates
- Thread ID in warp: 25
- Warp ID in block: 0
### Decompose Thread 25
```
Thread 25 in a 2D layout (M2=16, K0=4):
Row index: 25 / 4 = 6
Col index: 25 % 4 = 1
```
### M Position (Row)
```
M0 iteration: 0 (first iteration)
M1 warp: 0 (warp 0)
M2 thread: 6 (6th row in warp)
→ M position = 0*64 + 0*16 + 6 = 6
```
### K Position (Column)
```
K0 thread: 1 (column group 1)
K1 elements: 8 (will load 8 consecutive elements)
→ K position = 1*8 + [0-7] = elements 8-15
```
**Result:** Thread 25 in Warp 0 loads **row 6, columns 8-15** (8 elements).
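The same arithmetic generalizes to any lane. A minimal host-side sketch (`lane_to_tile_coord` and `Coord` are illustrative names, not CK API):
```cpp
#include <cstdio>

constexpr int K0 = 4, K1 = 8, M2 = 16; // layout constants derived above

struct Coord { int row, col_first, col_last; };

constexpr Coord lane_to_tile_coord(int lane, int warp, int iter)
{
    return {iter * 64 + warp * M2 + lane / K0, // global row (M1*M2 = 64)
            (lane % K0) * K1,                  // first of K1 columns
            (lane % K0) * K1 + K1 - 1};        // last of K1 columns
}

int main()
{
    constexpr Coord c = lane_to_tile_coord(25, 0, 0);
    static_assert(c.row == 6 && c.col_first == 8 && c.col_last == 15, "");
    std::printf("row %d, columns %d-%d\n", c.row, c.col_first, c.col_last);
}
```
Running it prints `row 6, columns 8-15`, matching the trace above.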
---
## Why This Matters
### 1. **Memory Coalescing**
- Consecutive threads access consecutive memory → efficient global memory access
- K dimension uses K1=8 for vectorized loads
### 2. **Warp Efficiency**
- All 64 threads in a warp are utilized
- Natural 2D layout: 16 threads (M) × 4 threads (K) = 64 threads
### 3. **Scalability**
- M0 repetitions allow handling larger tiles
- Same pattern scales to different sizes
### 4. **Register Allocation**
- Each thread knows exactly how many elements it will hold
- Compiler can allocate registers optimally
---
## Summary Table
| Parameter | Value | Meaning |
|-----------|-------|---------|
| **K1** | 8 | Elements per thread (vector width) |
| **K0** | 4 | Threads along K per row |
| **M2** | 16 | Threads along M per warp |
| **M1** | 4 | Warps per block |
| **M0** | 4 | Repetitions of warp pattern |
| **Total Threads** | 256 | M1×M2×K0 = 4×16×4 (4 warps × 64 threads) |
| **Total Elements** | 8192 | 256×32 (MPerBlock × KPerBlock) |
| **Elements/Thread** | 32 | M0×K1 = 4×8 |
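The table's invariants can be checked mechanically (a sketch restating the numbers above):
```cpp
constexpr int K1 = 8, K0 = 4, M2 = 16, M1 = 4, M0 = 4;
static_assert(M2 * K0 == 64,       "one warp's 2D layout uses all 64 lanes");
static_assert(M1 * M2 * K0 == 256, "total threads per block");
static_assert(M0 * M1 * M2 == 256, "rows covered = kMPerBlock");
static_assert(K0 * K1 == 32,       "columns covered = kKPerBlock");
static_assert(M0 * K1 == 32,       "elements held per thread");
```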
---
## Visualization: Complete Thread Block
```
Block Tile: 256×32
K dimension (32 elements)
├─────────────────────┤
  0 ┌──────────────────────┐ ┐
    │ Warp 0 (rows 0-15)   │ │
 16 │ Warp 1 (rows 16-31)  │ │ Iteration 0
 32 │ Warp 2 (rows 32-47)  │ │ (M0=0)
 48 │ Warp 3 (rows 48-63)  │ ┘
 64 ├──────────────────────┤ ┐
    │ Warp 0               │ │
 80 │ Warp 1               │ │ Iteration 1
 96 │ Warp 2               │ │ (M0=1)
112 │ Warp 3               │ ┘
128 ├──────────────────────┤ ┐
    │ Warp 0               │ │
144 │ Warp 1               │ │ Iteration 2
160 │ Warp 2               │ │ (M0=2)
176 │ Warp 3               │ ┘
192 ├──────────────────────┤ ┐
    │ Warp 0               │ │
208 │ Warp 1               │ │ Iteration 3
224 │ Warp 2               │ │ (M0=3)
240 │ Warp 3               │ ┘
256 └──────────────────────┘
Each warp processes 16 rows × 32 cols = 512 elements per iteration
Each iteration processes 64 rows × 32 cols = 2048 elements
Total: 4 iterations × 2048 = 8192 elements ✓
```
---
## Key Takeaways
1. **Bottom-up construction**: Start from vector width (K1), build up through thread/warp/block hierarchy
2. **Vertical reading**: The repeat and elements tuples are read column-wise to address hierarchy levels
3. **Replication controls redundancy**: How many warps share the same pattern
4. **Hierarchy encodes structure**: The multi-level sequence defines the complete mapping
This design lets CK achieve high GPU performance through an explicit, optimal thread-to-data mapping.


@@ -98,12 +98,12 @@ struct PracticeGemmBlockPolicy
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
tile_distribution_encoding<sequence<1>, // replication
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // hierarchy
tuple<sequence<1>, sequence<1, 2>>, // parallelism
tuple<sequence<1>, sequence<2, 0>>, // parallelism
sequence<1, 2>, // yield
sequence<0, 1>>{}); // yield
}
template <typename Problem>


@@ -24,7 +24,7 @@ struct PracticeGemmHostPipeline
template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
const BDRAMTensorView& b_dram,
CDRAMTensorView& c_dram_ref) const
CDRAMTensorView& c_dram) const
{
// Size of the entire problem


@@ -6,7 +6,7 @@
#include "practice_gemm.hpp"
#include "reference_gemm.hpp"
int main()
int main(int argc, char* argv[])
{
// TODO: GemmTypeConfig
using ADataType = ck_tile::half_t;
@@ -14,11 +14,22 @@ int main()
using CDataType = float;
using AccDataType = float;
// ArgParser
ck_tile::index_t M = 512;
ck_tile::index_t N = 256;
ck_tile::index_t K = 64;
ck_tile::index_t verification = 1;
// Setup simple argument parser for M, N, K
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "512", "m dimension")
.insert("n", "256", "n dimension")
.insert("k", "64", "k dimension")
.insert("v", "1", "verification: 0=off, 1=on");
auto result = arg_parser.parse(argc, argv);
if(!result)
return -1;
// Get problem dimensions from command line
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t verification = arg_parser.get_int("v");
ck_tile::index_t stride_a = K;
ck_tile::index_t stride_b = K;
@@ -61,9 +72,6 @@ int main()
ck_tile::DeviceMem c_device(c_host);
// TODO: BlockTileConfig
// constexpr ck_tile::index_t warpSize = 64;
constexpr ck_tile::index_t kBlockSize = 256;
using BlockTile = ck_tile::sequence<256, 128, 32>;
using WaveTile = ck_tile::sequence<16, 16, 16>;
@@ -77,11 +85,13 @@ int main()
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
std::cout << "kGridSize: " << kGridSize << std::endl;
std::cout << "Total number of thread blocks: " << kGridSize << std::endl;
constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU
std::cout << "kBlockSize: " << kBlockSize << std::endl;
std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl;
// Block size is now derived from the shape configuration
constexpr ck_tile::index_t kBlockSize = PracticeGemmShape::kBlockSize;
std::cout << "Number of threads per block: " << kBlockSize << std::endl;
std::cout << "Number of blocks per compute unit: " << kBlockPerCU << std::endl;
using gemm_kernel =
ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;


@@ -24,6 +24,10 @@ struct PracticeGemmShape
static constexpr index_t WaveTile_N = WaveTile::at(number<1>{});
static constexpr index_t WaveTile_K = WaveTile::at(number<2>{});
// Thread block configuration
static constexpr index_t kWarpSize = 64; // AMD GPU warp size (also called a wavefront)
static constexpr index_t kBlockSize = 256; // Total threads per block (4 warps × 64 threads)
CK_TILE_HOST static std::string GetName()
{
// clang-format off
@@ -40,7 +44,8 @@ struct PracticeGemmKernel
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
static constexpr index_t kBlockSize = 256;
// Derive block size from the shape configuration
static constexpr index_t kBlockSize = Problem::Shape::kBlockSize;
CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
const typename Problem::BDataType* p_b,