mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
[rocm-libraries] ROCm/rocm-libraries#7714 (commit 13ae6d6)
[CK_TILE] Restructure naive GEMM tutorial and add tile distribution tutorials (#7714) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Flatten naive GEMM tutorial directory structure (remove `block_level/`, `host_level/`, `warp_level/` subdirs) to match the composable_kernel repo layout - Add `CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION` macro switch to toggle between standard and transposed WarpGemm variants - Consolidate 6 verbose markdown files (~2600 lines) into one concise README (~120 lines) - Add 3 tile distribution encoding tutorials with step-by-step "How to read Ps/Ys" annotations: - Tutorial 1: A-matrix DRAM load (256×32) — NDimP=2, coalesced K-splitting - Tutorial 2: B-matrix DRAM load (128×32) — same pattern, fewer iterations - Tutorial 3: C-matrix register layout (32×32) — MFMA m32n32k8 hardware output mapping, standard vs transposed - Tile distribution tutorials guarded to build only for gfx942 and gfx950
This commit is contained in:
committed by
assistant-librarian[bot]
parent
a56c8d6017
commit
c73c50a96e
@@ -7,3 +7,4 @@ include_directories(AFTER
|
||||
|
||||
add_subdirectory(00_copy_kernel)
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(tile_distribution)
|
||||
|
||||
@@ -1,589 +0,0 @@
|
||||
# Block-Level Pipeline: PracticeGemmBlockPipelineAGmemBGmemCreg
|
||||
|
||||
## Overview
|
||||
|
||||
The **Block-Level Pipeline** is where the actual GEMM computation happens for one block tile. It orchestrates:
|
||||
1. **Data movement** from DRAM → Registers → LDS
|
||||
2. **GEMM computation** using data in LDS
|
||||
3. **Iteration** over the K dimension when needed
|
||||
|
||||
This pipeline is called by the host-level pipeline for each block tile that covers a portion of the output matrix C.
|
||||
|
||||
---
|
||||
|
||||
## Architecture: Problem and Policy
|
||||
|
||||
Like other components in CK Tile, the block pipeline follows the **Problem/Policy** pattern:
|
||||
|
||||
### Problem: `PracticeGemmBlockPipelineProblem`
|
||||
Contains:
|
||||
- **Data types**: `ADataType`, `BDataType`, `CDataType`, `AccDataType`
|
||||
- **Shape information**: `BlockTile` and `WaveTile` dimensions
|
||||
|
||||
### Policy: `PracticeGemmBlockPolicy`
|
||||
Contains strategies for:
|
||||
1. **Tile Distribution** (`MakeADramTileDistribution`, `MakeBDramTileDistribution`)
|
||||
- Defines how 256 threads in a block map to elements of a block tile
|
||||
- Each thread knows which elements to load/store from DRAM to its registers
|
||||
- We'll cover tile distribution construction in detail later
|
||||
|
||||
2. **LDS Layout** (`MakeALdsBlockDescriptor`, `MakeBLdsBlockDescriptor`)
|
||||
- Describes how data is logically organized in Local Data Share (LDS)
|
||||
- Optimizes for bank conflict avoidance and efficient access patterns
|
||||
- We'll cover LDS descriptor construction in detail later
|
||||
|
||||
3. **Warp Pipeline** (`GetPracticeWaveGemmPipeline`)
|
||||
- Returns the warp-level GEMM implementation
|
||||
|
||||
---
|
||||
|
||||
## Inputs and Outputs
|
||||
|
||||
```cpp
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
```
|
||||
|
||||
### Inputs:
|
||||
- `a_dram_block_window_tmp`: Tile window over A in DRAM (size: MPerBlock × KPerBlock)
|
||||
- `b_dram_block_window_tmp`: Tile window over B in DRAM (size: NPerBlock × KPerBlock)
|
||||
- `num_loop`: Number of iterations along K dimension
|
||||
- `p_smem`: Pointer to shared memory (LDS)
|
||||
|
||||
### Output:
|
||||
- `c_block_tile`: A `static_distributed_tensor` containing the computed C tile in registers (VGPRs)
|
||||
|
||||
---
|
||||
|
||||
## Step-by-Step Walkthrough
|
||||
|
||||
### Step 1: Create LDS Tensor Views
|
||||
|
||||
```cpp
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// B tile in LDS (placed after A in shared memory)
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We partition the shared memory (`p_smem`) into two regions: one for A, one for B
|
||||
- We create **tensor views** over these LDS regions using descriptors from the policy
|
||||
- `a_lds_block` and `b_lds_block` are logical views over raw LDS memory
|
||||
|
||||
**Memory Layout:**
|
||||
```
|
||||
Shared Memory (LDS):
|
||||
┌─────────────────────┬─────────────────────┐
|
||||
│ A Block Tile │ B Block Tile │
|
||||
│ (256×32 fp16) │ (128×32 fp16) │
|
||||
└─────────────────────┴─────────────────────┘
|
||||
↑ ↑
|
||||
p_a_lds p_b_lds
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 2: Create Tile Windows for Data Movement
|
||||
|
||||
We create **6 tile windows** for different purposes:
|
||||
|
||||
#### 2a. DRAM → Registers (Load from DRAM)
|
||||
|
||||
```cpp
|
||||
auto a_copy_dram_window = make_tile_window(
|
||||
a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>()); // ← Tile distribution!
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- `a_copy_dram_window` is a `tile_window_with_static_distribution`
|
||||
- The **tile distribution** tells each thread which elements to load from DRAM
|
||||
- This window will **slide along the K dimension** in the loop
|
||||
|
||||
#### 2b. Registers → LDS (Store to LDS)
|
||||
|
||||
```cpp
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32
|
||||
{0, 0}, // Origin at (0, 0) in LDS
|
||||
a_copy_dram_window.get_tile_distribution()); // ← Same distribution as DRAM!
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- Uses the **same tile distribution** as `a_copy_dram_window`
|
||||
- This ensures each thread stores to LDS in the same pattern it loaded from DRAM
|
||||
- Origin is always `{0, 0}` because LDS is reused for each K iteration
|
||||
|
||||
#### 2c. LDS → Registers (GEMM Input)
|
||||
|
||||
```cpp
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0}); // No tile distribution!
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- This is a `tile_window_with_static_lengths` (no explicit distribution)
|
||||
- Used as input to the warp-level GEMM
|
||||
- The warp GEMM will handle its own thread mapping internally
|
||||
|
||||
**Similar windows are created for B:**
|
||||
- `b_copy_dram_window`: Load B from DRAM
|
||||
- `b_copy_lds_window`: Store B to LDS
|
||||
- `b_lds_gemm_window`: Read B from LDS for GEMM
|
||||
|
||||
---
|
||||
|
||||
### Step 3: Create Distributed Tensors (VGPRs)
|
||||
|
||||
```cpp
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile; // Per-thread registers for A
|
||||
BBlockTile b_block_tile; // Per-thread registers for B
|
||||
```
|
||||
|
||||
#### What is `make_static_distributed_tensor`?
|
||||
|
||||
**`make_static_distributed_tensor`** creates a **`static_distributed_tensor`**, which is a compile-time abstraction for **distributed per-thread register storage**.
|
||||
|
||||
**Key Properties:**
|
||||
1. **Per-thread VGPRs**: Each thread owns a **different slice** of the tile in its registers
|
||||
2. **Compile-time sized**: Buffer size determined by tile distribution at compile time
|
||||
3. **Zero-overhead**: All indexing and layout transformations happen at compile time
|
||||
|
||||
**How it works:**
|
||||
|
||||
```cpp
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
// Calculate per-thread storage size from tile distribution
|
||||
using ThreadTensorDesc =
|
||||
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
|
||||
|
||||
static constexpr index_t kThreadElementSpaceSize =
|
||||
ThreadTensorDesc{}.get_element_space_size();
|
||||
|
||||
// Per-thread register array (VGPRs)
|
||||
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
|
||||
};
|
||||
```
|
||||
|
||||
**The tile distribution defines:**
|
||||
- **Which elements each thread owns** in the tile
|
||||
- **How many elements** each thread stores (buffer size)
|
||||
- **How elements are laid out** in each thread's registers
|
||||
|
||||
**Concrete Example for 256×32 tile with 256 threads:**
|
||||
|
||||
```
|
||||
Thread 0: a_block_tile.thread_buf_ = [A[0,0], A[0,1], ..., A[0,31]] (32 fp16 values)
|
||||
Thread 1: a_block_tile.thread_buf_ = [A[1,0], A[1,1], ..., A[1,31]] (32 fp16 values)
|
||||
Thread 2: a_block_tile.thread_buf_ = [A[2,0], A[2,1], ..., A[2,31]] (32 fp16 values)
|
||||
...
|
||||
Thread 255: a_block_tile.thread_buf_ = [A[255,0], A[255,1], ..., A[255,31]] (32 fp16 values)
|
||||
```
|
||||
|
||||
**Collectively:**
|
||||
- All 256 threads together hold the **entire 256×32 tile** (8192 elements)
|
||||
- Each thread's buffer lives in its **own VGPRs**
|
||||
- No two threads own the same element
|
||||
|
||||
**Distributed Ownership Analogy:**
|
||||
Think of a tile as a **jigsaw puzzle**:
|
||||
- The **tile distribution** is the cutting pattern
|
||||
- Each **thread** gets one puzzle piece (its slice)
|
||||
- Each **`static_distributed_tensor`** is a box holding all pieces
|
||||
- Each thread's **`thread_buf_`** is its individual piece in its own registers
|
||||
|
||||
---
|
||||
|
||||
### Step 4: The GEMM Loop
|
||||
|
||||
```cpp
|
||||
// Initialize C accumulator to zero
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
index_t iCounter = num_loop; // Number of K iterations
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// 1. Load from DRAM to registers
|
||||
a_block_tile = load_tile(a_copy_dram_window); // DRAM → VGPRs
|
||||
b_block_tile = load_tile(b_copy_dram_window); // DRAM → VGPRs
|
||||
|
||||
// 2. Move windows for next iteration
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step); // Step by (0, 32)
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // Step by (0, 32)
|
||||
|
||||
// 3. Store from registers to LDS
|
||||
store_tile(a_copy_lds_window, a_block_tile); // VGPRs → LDS
|
||||
store_tile(b_copy_lds_window, b_block_tile); // VGPRs → LDS
|
||||
|
||||
// 4. Synchronize threads (ensure all data is in LDS)
|
||||
block_sync_lds();
|
||||
|
||||
// 5. Compute GEMM using data in LDS
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// 6. Synchronize threads (before overwriting LDS in next iteration)
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
return c_block_tile; // Return accumulated result in registers
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Detailed Loop Breakdown
|
||||
|
||||
### Phase 1: Load (DRAM → VGPRs)
|
||||
|
||||
```cpp
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. Each thread reads **its assigned elements** from DRAM (determined by tile distribution)
|
||||
2. Data is loaded into **per-thread registers** (VGPRs)
|
||||
3. Uses **vectorized loads** for efficiency (e.g., loading 8 fp16 values at once)
|
||||
|
||||
**Example for Thread 0:**
|
||||
```
|
||||
Thread 0 loads:
|
||||
A[0,0:7] (8 fp16 values, one vector load)
|
||||
A[1,0:7] (8 fp16 values, one vector load)
|
||||
...
|
||||
```
|
||||
|
||||
### Phase 2: Move Windows
|
||||
|
||||
```cpp
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
- The tile window **slides along the K dimension** by `KPerBlock` (32 in our example)
|
||||
- This prepares for the next K iteration
|
||||
- The window origin moves from `(0, 0)` → `(0, 32)` → `(0, 64)` → ...
|
||||
|
||||
**Visualization for Problem Size 512×256×64:**
|
||||
```
|
||||
Matrix A (512×64):
|
||||
┌─────────────────────────────────────┐
|
||||
│ Block 0: rows 0-255 │
|
||||
│ ┌──────────┬──────────┐ │
|
||||
│ │ K=0:31 │ K=32:63 │ │ ← Window slides right
|
||||
│ │ Iter 0 │ Iter 1 │ │
|
||||
│ └──────────┴──────────┘ │
|
||||
└─────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Phase 3: Store (VGPRs → LDS)
|
||||
|
||||
```cpp
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. Each thread writes **its elements** from registers to LDS
|
||||
2. Uses the **same distribution** as the DRAM load
|
||||
3. Data is now in **shared memory**, accessible to all threads in the block
|
||||
|
||||
**Why this step?**
|
||||
- GEMM computation needs **all threads** to access **all data**
|
||||
- Registers are per-thread; LDS is shared across the block
|
||||
- LDS acts as a "staging area" for collaborative computation
|
||||
|
||||
### Phase 4: Synchronize
|
||||
|
||||
```cpp
|
||||
block_sync_lds();
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
- All threads in the block **wait** until everyone has finished storing to LDS
|
||||
- Ensures no thread starts reading from LDS before all writes are complete
|
||||
- Critical for correctness!
|
||||
|
||||
### Phase 5: GEMM Computation
|
||||
|
||||
```cpp
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. The warp-level GEMM reads data from LDS
|
||||
2. Performs matrix multiplication using MFMA instructions
|
||||
3. Accumulates results into `c_block_tile` (in registers)
|
||||
|
||||
**Note:** `c_block_tile` stays in registers throughout all K iterations, accumulating results.
|
||||
|
||||
### Phase 6: Synchronize Again
|
||||
|
||||
```cpp
|
||||
block_sync_lds();
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
- Ensures all threads have finished reading from LDS
|
||||
- Safe to overwrite LDS in the next iteration
|
||||
|
||||
---
|
||||
|
||||
## Memory Flow Diagram
|
||||
|
||||
```
|
||||
Iteration 0 (K=0:31):
|
||||
┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐
|
||||
│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │
|
||||
│ A[0:255,│ │ (per- │ │ A_block │
|
||||
│ 0:31] │ │ thread) │ │ │
|
||||
└─────────┘ └──────────┘ └─────────┘
|
||||
│
|
||||
│ block_gemm
|
||||
↓
|
||||
┌──────────┐
|
||||
│ c_block_ │
|
||||
│ tile │
|
||||
│ (VGPRs) │
|
||||
└──────────┘
|
||||
|
||||
Iteration 1 (K=32:63):
|
||||
┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐
|
||||
│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │
|
||||
│ A[0:255,│ │ (per- │ │ A_block │
|
||||
│ 32:63] │ │ thread) │ │ (reused)│
|
||||
└─────────┘ └──────────┘ └─────────┘
|
||||
│
|
||||
│ block_gemm
|
||||
↓
|
||||
┌──────────┐
|
||||
│ c_block_ │
|
||||
│ tile │
|
||||
│ (accum.) │
|
||||
└──────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Example: Problem Size 512×256×64
|
||||
|
||||
### Block 0 Computation
|
||||
|
||||
**Input:**
|
||||
- `a_dram_block_window_tmp`: Covers A[0:255, 0:31] initially
|
||||
- `b_dram_block_window_tmp`: Covers B[0:127, 0:31] initially (B is transposed)
|
||||
- `num_loop`: 2 (since K=64, KPerBlock=32)
|
||||
|
||||
**Iteration 0:**
|
||||
1. Load A[0:255, 0:31] and B[0:127, 0:31] from DRAM to VGPRs
|
||||
2. Move windows: A → [0:255, 32:63], B → [0:127, 32:63]
|
||||
3. Store to LDS
|
||||
4. Compute: `C[0:255, 0:127] += A[0:255, 0:31] × B[0:127, 0:31]^T`
|
||||
|
||||
**Iteration 1:**
|
||||
1. Load A[0:255, 32:63] and B[0:127, 32:63] from DRAM to VGPRs
|
||||
2. Move windows: A → [0:255, 64:95], B → [0:127, 64:95] (out of bounds, but loop ends)
|
||||
3. Store to LDS
|
||||
4. Compute: `C[0:255, 0:127] += A[0:255, 32:63] × B[0:127, 32:63]^T`
|
||||
|
||||
**Output:**
|
||||
- `c_block_tile`: Contains C[0:255, 0:127] in distributed registers
|
||||
|
||||
---
|
||||
|
||||
## Key Concepts Summary
|
||||
|
||||
### 1. Tile Distribution
|
||||
- **Maps threads to data elements** for load/store operations
|
||||
- Each thread knows exactly which elements it's responsible for
|
||||
- Enables **parallel, vectorized** memory access
|
||||
- **Same distribution** used for DRAM load and LDS store
|
||||
|
||||
### 2. Static Distributed Tensor
|
||||
- **Per-thread register storage** (VGPRs)
|
||||
- Each thread owns a **different slice** of the tile
|
||||
- **Compile-time sized** for zero-overhead abstraction
|
||||
- Used for: `a_block_tile`, `b_block_tile`, `c_block_tile`
|
||||
|
||||
### 3. Tile Window Movement
|
||||
- Windows **slide** over larger tensors
|
||||
- Enables iteration over the K dimension
|
||||
- `move_tile_window(window, step)` updates the origin
|
||||
|
||||
### 4. LDS as Staging Area
|
||||
- **Shared memory** accessible to all threads in a block
|
||||
- Required because GEMM needs all threads to access all data
|
||||
- **Reused** across K iterations (same LDS buffer)
|
||||
|
||||
### 5. Synchronization
|
||||
- `block_sync_lds()` ensures memory consistency
|
||||
- **Before GEMM**: All stores to LDS are complete
|
||||
- **After GEMM**: All reads from LDS are complete
|
||||
|
||||
---
|
||||
|
||||
## Deep Dive: `static_distributed_tensor` Mechanics
|
||||
|
||||
### How Tile Distribution Creates Per-Thread Storage
|
||||
|
||||
When you call:
|
||||
```cpp
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<fp16_t>(ABlockTileDistr{}));
|
||||
ABlockTile a_block_tile;
|
||||
```
|
||||
|
||||
**Step 1: Extract Thread Tensor Descriptor**
|
||||
|
||||
The tile distribution contains a `ys_to_d_descriptor` that maps:
|
||||
- **Y dimensions** (logical tile coordinates, e.g., M, K)
|
||||
- **D dimension** (per-thread register index, linearized)
|
||||
|
||||
```cpp
|
||||
using ThreadTensorDesc =
|
||||
decltype(StaticTileDistribution{}.get_ys_to_d_descriptor());
|
||||
```
|
||||
|
||||
**Step 2: Calculate Per-Thread Buffer Size**
|
||||
|
||||
```cpp
|
||||
static constexpr index_t kThreadElementSpaceSize =
|
||||
ThreadTensorDesc{}.get_element_space_size();
|
||||
|
||||
static constexpr index_t get_thread_buffer_size()
|
||||
{
|
||||
return kThreadElementSpaceSize / PackedSize;
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
- 256×32 tile distributed across 256 threads
|
||||
- Each thread owns 32 elements (one row)
|
||||
- `thread_buffer_size = 32` (for PackedSize=1)
|
||||
|
||||
**Step 3: Allocate Thread Buffer**
|
||||
|
||||
```cpp
|
||||
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
|
||||
```
|
||||
|
||||
This is essentially:
|
||||
```cpp
|
||||
fp16_t data[32]; // Per-thread register array (VGPRs)
|
||||
```
|
||||
|
||||
### Usage in Load/Store Operations
|
||||
|
||||
**Load from DRAM:**
|
||||
```cpp
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
```
|
||||
|
||||
What happens internally:
|
||||
1. Each thread queries the tile distribution: "Which elements do I own?"
|
||||
2. Thread 0 learns it owns A[0,0:31]
|
||||
3. Thread 0 loads those elements from DRAM into `a_block_tile.thread_buf_[0:31]`
|
||||
4. All 256 threads do this **in parallel**
|
||||
|
||||
**Store to LDS:**
|
||||
```cpp
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
```
|
||||
|
||||
What happens internally:
|
||||
1. Each thread reads from its `a_block_tile.thread_buf_`
|
||||
2. Thread 0 writes A[0,0:31] from its registers to LDS
|
||||
3. All 256 threads do this **in parallel**
|
||||
4. After `block_sync_lds()`, the entire tile is in shared LDS
|
||||
|
||||
### Distributed Indexing
|
||||
|
||||
The `static_distributed_tensor` supports compile-time indexing:
|
||||
|
||||
```cpp
|
||||
// Access using distributed indices
|
||||
auto value = a_block_tile(tile_distributed_index<i, j>{});
|
||||
```
|
||||
|
||||
Internally:
|
||||
1. Convert distributed index → Y index (logical tile coordinates)
|
||||
2. Calculate buffer offset using `ThreadTensorDesc`
|
||||
3. Access `thread_buf_[offset]`
|
||||
|
||||
All of this happens **at compile time** with zero runtime overhead!
|
||||
|
||||
### Why This Design?
|
||||
|
||||
**Benefits:**
|
||||
1. **Parallel Memory Access**: All threads load/store simultaneously
|
||||
2. **Vectorization**: Each thread can use vector loads (e.g., 8×fp16 at once)
|
||||
3. **Zero Overhead**: All indexing resolved at compile time
|
||||
4. **Type Safety**: Distribution mismatch caught at compile time
|
||||
5. **Register Pressure**: Compiler knows exact VGPR usage
|
||||
|
||||
**Trade-offs:**
|
||||
- Requires compile-time tile sizes
|
||||
- Distribution must be static
|
||||
- More complex type system
|
||||
|
||||
### Memory Hierarchy Summary
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ DRAM (Global Memory) │
|
||||
│ Full matrices A, B, C │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
│ load_tile (parallel, vectorized)
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ VGPRs (Per-Thread Registers) │
|
||||
│ Thread 0: a_block_tile.thread_buf_ = [A[0,0:31]] │
|
||||
│ Thread 1: a_block_tile.thread_buf_ = [A[1,0:31]] │
|
||||
│ ... │
|
||||
│ Thread 255: a_block_tile.thread_buf_ = [A[255,0:31]] │
|
||||
│ │
|
||||
│ ← static_distributed_tensor manages this distribution │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
│ store_tile (parallel, vectorized)
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ LDS (Shared Memory) │
|
||||
│ Entire block tile (256×32) │
|
||||
│ Accessible to all threads in block │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key Insight:**
|
||||
`static_distributed_tensor` is the abstraction that enables efficient, parallel data movement between DRAM and LDS through per-thread VGPRs, with all coordination happening at compile time.
|
||||
|
||||
|
||||
|
||||
@@ -1,618 +0,0 @@
|
||||
# Host-Level Pipeline: Orchestrating Block-Level GEMM
|
||||
|
||||
This document explains the **host-level pipeline** (`PracticeGemmHostPipeline`), which orchestrates the distribution of work across thread blocks and manages the high-level flow of the GEMM computation.
|
||||
|
||||
## Overview
|
||||
|
||||
The host-level pipeline is responsible for:
|
||||
1. **Calculating tile coverage**: How many tiles are needed to cover matrices A, B, and C
|
||||
2. **Block-to-tile mapping**: Assigning each thread block to a specific tile
|
||||
3. **Creating tile windows**: Establishing sliding windows over tensor views
|
||||
4. **Delegating computation**: Calling the block-level pipeline to perform actual GEMM
|
||||
5. **Storing results**: Writing computed tiles from registers (VGPRs) back to DRAM
|
||||
|
||||
```cpp
|
||||
template <typename Problem_, typename Policy_ = PracticeGemmHostPolicy>
|
||||
struct PracticeGemmHostPipeline
|
||||
{
|
||||
template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
|
||||
CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
|
||||
const BDRAMTensorView& b_dram,
|
||||
CDRAMTensorView& c_dram) const
|
||||
{
|
||||
// 1. Calculate problem dimensions and tile coverage
|
||||
// 2. Map thread block to tile coordinates
|
||||
// 3. Create tile windows over A and B
|
||||
// 4. Call block-level pipeline to compute
|
||||
// 5. Store result to C
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 1: Calculate Problem Dimensions and Tile Coverage
|
||||
|
||||
```cpp
|
||||
// Size of the entire problem
|
||||
const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K
|
||||
const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N
|
||||
const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K
|
||||
|
||||
// Size of the block tile
|
||||
const auto MPerBlock = BlockTile::at(number<0>{}); // 256
|
||||
const auto NPerBlock = BlockTile::at(number<1>{}); // 128
|
||||
const auto KPerBlock = BlockTile::at(number<2>{}); // 32
|
||||
|
||||
// Number of block tiles needed to cover C matrix
|
||||
const auto num_tile_n = integer_divide_ceil(N, NPerBlock); // ceil(256/128) = 2
|
||||
const auto num_tile_m = integer_divide_ceil(M, MPerBlock); // ceil(512/256) = 2
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
1. **Extract problem dimensions** from tensor descriptors:
|
||||
- `M = 512`: Rows in A and C
|
||||
- `N = 256`: Columns in B and C
|
||||
- `K = 64`: Inner dimension (columns of A, rows of B)
|
||||
|
||||
2. **Get block tile sizes** from the `BlockTile` configuration:
|
||||
- `MPerBlock = 256`: Each block processes 256 rows
|
||||
- `NPerBlock = 128`: Each block processes 128 columns
|
||||
- `KPerBlock = 32`: Each block processes 32 elements in K dimension per iteration
|
||||
|
||||
3. **Calculate tile coverage**:
|
||||
- `num_tile_m = ceil(M / MPerBlock) = ceil(512/256) = 2` tiles in M direction
|
||||
- `num_tile_n = ceil(N / NPerBlock) = ceil(256/128) = 2` tiles in N direction
|
||||
- **Total tiles = 2 × 2 = 4 tiles** → We need **4 thread blocks**!
|
||||
|
||||
### Visual Representation:
|
||||
|
||||
```
|
||||
Matrix C (512 × 256):
|
||||
┌──────────────────────┬──────────────────────┐
|
||||
│ Tile (0,0) │ Tile (0,1) │ ← num_tile_n = 2
|
||||
│ 256×128 │ 256×128 │
|
||||
│ Block 0 │ Block 1 │
|
||||
│ │ │
|
||||
├──────────────────────┼──────────────────────┤
|
||||
│ Tile (1,0) │ Tile (1,1) │
|
||||
│ 256×128 │ 256×128 │
|
||||
│ Block 2 │ Block 3 │
|
||||
│ │ │
|
||||
└──────────────────────┴──────────────────────┘
|
||||
↑
|
||||
num_tile_m = 2
|
||||
|
||||
Total blocks needed = 2 × 2 = 4 blocks
|
||||
|
||||
Each block computes one 256×128 tile of the output matrix C.
|
||||
```
|
||||
|
||||
### How Blocks Cover Matrices A and B:
|
||||
|
||||
```
|
||||
Matrix A (512 × 64): Matrix B (256 × 64):
|
||||
┌─────────────┬──────┐ ┌─────────────┬──────┐
|
||||
│ Block 0,2 │ K │ │ Block 0,1 │ K │
|
||||
│ uses rows │ → │ │ uses rows │ → │
|
||||
│ 0-255 │ │ │ 0-127 │ │
|
||||
├─────────────┼──────┤ ├─────────────┼──────┤
|
||||
│ Block 1,3 │ K │ │ Block 2,3 │ K │
|
||||
│ uses rows │ → │ │ uses rows │ → │
|
||||
│ 256-511 │ │ │ 128-255 │ │
|
||||
└─────────────┴──────┘ └─────────────┴──────┘
|
||||
256 rows 64 cols 128 rows 64 cols
|
||||
|
||||
Each block needs to iterate over K dimension (64/32 = 2 iterations)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 2: Map Thread Block to Tile Coordinates
|
||||
|
||||
```cpp
|
||||
// Get block id (0 to total_blocks - 1)
|
||||
const auto id_block = get_block_id();
|
||||
|
||||
// Map block id to 2D tile coordinates
|
||||
const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n);
|
||||
const auto tile_id = block2tile(id_block);
|
||||
|
||||
const auto tile_id_m = tile_id.at(number<0>{}); // M coordinate
|
||||
const auto tile_id_n = tile_id.at(number<1>{}); // N coordinate
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
Each thread block needs to know **which tile of the output matrix C it should compute**. The `MakeBlock2TileMap` function creates a mapping from linear block ID to 2D tile coordinates.
|
||||
|
||||
### The `MakeBlock2TileMap` Function:
|
||||
|
||||
```cpp
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
// Create a merge transform: (N0, M0) → linear index
|
||||
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
|
||||
|
||||
return [unmerge](index_t block_id) {
|
||||
multi_index<2> unmerged;
|
||||
// Convert linear block_id back to 2D coordinates
|
||||
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
|
||||
|
||||
// Return (m_idx, n_idx) - note the swap!
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
### In Our Example (2×2 Grid):
|
||||
|
||||
```cpp
|
||||
// Block 0:
|
||||
id_block = 0
|
||||
tile_id = block2tile(0) = (0, 0) // Top-left tile
|
||||
tile_id_m = 0, tile_id_n = 0
|
||||
|
||||
// Block 1:
|
||||
id_block = 1
|
||||
tile_id = block2tile(1) = (1, 0) // Bottom-left tile
|
||||
tile_id_m = 1, tile_id_n = 0
|
||||
|
||||
// Block 2:
|
||||
id_block = 2
|
||||
tile_id = block2tile(2) = (0, 1) // Top-right tile
|
||||
tile_id_m = 0, tile_id_n = 1
|
||||
|
||||
// Block 3:
|
||||
id_block = 3
|
||||
tile_id = block2tile(3) = (1, 1) // Bottom-right tile
|
||||
tile_id_m = 1, tile_id_n = 1
|
||||
```
|
||||
|
||||
**Key Point**: Each of the 4 blocks knows exactly which 256×128 tile of C it's responsible for computing!
|
||||
|
||||
---
|
||||
|
||||
## Step 3: Calculate Tile Origin and Create Tile Windows
|
||||
|
||||
```cpp
|
||||
// Calculate the starting position of this tile in the global matrix
|
||||
const auto tile_origin_m = tile_id_m * MPerBlock; // e.g., Block 1: 1 * 256 = 256
|
||||
const auto tile_origin_n = tile_id_n * NPerBlock; // e.g., Block 2: 1 * 128 = 128
|
||||
|
||||
// Create tile windows over A and B tensor views
|
||||
const auto a_block_window = make_tile_window(
|
||||
a_dram, // Tensor view over A
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // Window size: 256×32
|
||||
{tile_origin_m, 0} // Origin: varies by block
|
||||
);
|
||||
|
||||
const auto b_block_window = make_tile_window(
|
||||
b_dram, // Tensor view over B
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), // Window size: 128×32
|
||||
{tile_origin_n, 0} // Origin: varies by block
|
||||
);
|
||||
```
|
||||
|
||||
### Tile Origins for Each Block:
|
||||
|
||||
```cpp
|
||||
// Block 0 (Tile 0,0):
|
||||
tile_origin_m = 0 * 256 = 0
|
||||
tile_origin_n = 0 * 128 = 0
|
||||
a_block_window origin: (0, 0) → covers A rows 0-255
|
||||
b_block_window origin: (0, 0) → covers B rows 0-127
|
||||
|
||||
// Block 1 (Tile 1,0):
|
||||
tile_origin_m = 1 * 256 = 256
|
||||
tile_origin_n = 0 * 128 = 0
|
||||
a_block_window origin: (256, 0) → covers A rows 256-511
|
||||
b_block_window origin: (0, 0) → covers B rows 0-127
|
||||
|
||||
// Block 2 (Tile 0,1):
|
||||
tile_origin_m = 0 * 256 = 0
|
||||
tile_origin_n = 1 * 128 = 128
|
||||
a_block_window origin: (0, 0) → covers A rows 0-255
|
||||
b_block_window origin: (128, 0) → covers B rows 128-255
|
||||
|
||||
// Block 3 (Tile 1,1):
|
||||
tile_origin_m = 1 * 256 = 256
|
||||
tile_origin_n = 1 * 128 = 128
|
||||
a_block_window origin: (256, 0) → covers A rows 256-511
|
||||
b_block_window origin: (128, 0) → covers B rows 128-255
|
||||
```
|
||||
|
||||
### What are Tile Windows?
|
||||
|
||||
A **tile window** is a **sliding window** over a larger tensor view. It:
|
||||
- Defines a **rectangular region** within the tensor
|
||||
- Has a **fixed size** (e.g., 256×32 for A)
|
||||
- Has an **origin** (starting position)
|
||||
- Can be **moved** to access different regions
|
||||
### Visual Representation (Block 0 Example):
|
||||
|
||||
```
|
||||
Matrix A (512 × 64): Matrix B (256 × 64):
|
||||
┌─────────────┬─────────────┐ ┌─────────────┬─────────────┐
|
||||
│ ┏━━━━━━━━━┓ │ │ │ ┏━━━━━━━━━┓ │ │
|
||||
│ ┃ Window ┃ │ │ │ ┃ Window ┃ │ │
|
||||
│ ┃ 256×32 ┃ │ │ │ ┃ 128×32 ┃ │ │
|
||||
│ ┃ K=0-31 ┃ │ │ │ ┃ K=0-31 ┃ │ │
|
||||
│ ┗━━━━━━━━━┛ │ │ │ ┗━━━━━━━━━┛ │ │
|
||||
│ │ │ ├─────────────┼─────────────┤
|
||||
├─────────────┼─────────────┤ │ │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ │ │ │ │
|
||||
└─────────────┴─────────────┘ └─────────────┴─────────────┘
|
||||
Origin: (0, 0) Origin: (0, 0)
|
||||
Covers rows 0-255 Covers rows 0-127
|
||||
Covers cols 0-31 (first K iteration) Covers cols 0-31 (first K iteration)
|
||||
```
|
||||
|
||||
**Note**: The window initially covers K columns 0-31. It will move to cover K columns 32-63 in the next iteration.
|
||||
|
||||
### Tile Window Properties:
|
||||
|
||||
```cpp
|
||||
// Tile window structure (conceptual):
|
||||
struct tile_window {
|
||||
TensorView& tensor_view; // Reference to underlying tensor
|
||||
Tuple window_lengths; // Size of the window (256, 32)
|
||||
MultiIndex window_origin; // Starting position (0, 0)
|
||||
|
||||
// Can move the window:
|
||||
void move(MultiIndex step); // Shift window by step
|
||||
|
||||
// Access data through the window:
|
||||
auto load(); // Load data from windowed region
|
||||
};
|
||||
```
|
||||
|
||||
|
||||
### Tile Window Movement: Iterating Over K Dimension
|
||||
|
||||
In our example, **K=64** but **KPerBlock=32**, so we need **2 iterations** over the K dimension:
|
||||
|
||||
```
|
||||
Matrix A (512 × 64) - Block 0's view:
|
||||
┌─────────────┬─────────────┐
|
||||
│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │
|
||||
│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ ← Window slides along K
|
||||
│ ┃ 256×32 ┃ │ ║ 256×32 ║ │
|
||||
│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │
|
||||
│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │
|
||||
├─────────────┼─────────────┤
|
||||
│ │ │
|
||||
│ Block 1's │ │
|
||||
│ region │ │
|
||||
└─────────────┴─────────────┘
|
||||
|
||||
Matrix B (256 × 64) - Block 0's view:
|
||||
┌─────────────┬─────────────┐
|
||||
│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │
|
||||
│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │
|
||||
│ ┃ 128×32 ┃ │ ║ 128×32 ║ │
|
||||
│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │
|
||||
│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │
|
||||
├─────────────┼─────────────┤
|
||||
│ Block 2's │ │
|
||||
│ region │ │
|
||||
└─────────────┴─────────────┘
|
||||
```
|
||||
|
||||
### How Windows Move (Conceptual - handled by block pipeline):
|
||||
|
||||
```cpp
|
||||
// Iteration 0:
|
||||
a_block_window origin: (tile_origin_m, 0) // K columns 0-31
|
||||
b_block_window origin: (tile_origin_n, 0) // K columns 0-31
|
||||
// Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31]
|
||||
|
||||
// Move windows to next K position:
|
||||
move_tile_window(a_block_window, {0, 32});
|
||||
move_tile_window(b_block_window, {0, 32});
|
||||
|
||||
// Iteration 1:
|
||||
a_block_window origin: (tile_origin_m, 32) // K columns 32-63
|
||||
b_block_window origin: (tile_origin_n, 32) // K columns 32-63
|
||||
// Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63]
|
||||
|
||||
// Final result:
|
||||
// C_tile = C_partial_0 + C_partial_1
|
||||
```
|
||||
|
||||
**Key Insight**: The tile windows **slide along the K dimension** to cover the full inner product. Each block accumulates partial results across K iterations to compute its final tile of C.
|
||||
|
||||
---
|
||||
|
||||
## Step 4: Delegate to Block-Level Pipeline
|
||||
|
||||
```cpp
|
||||
// Get the block-level pipeline from policy
|
||||
constexpr auto block_gemm_pipeline =
|
||||
Policy::template GetPracticeGemmBlockPipeline<Problem>();
|
||||
|
||||
// Calculate number of K iterations needed
|
||||
int num_loops_k = integer_divide_ceil(K, KPerBlock); // ceil(64/32) = 2
|
||||
|
||||
// Allocate shared memory (LDS) for block-level computation
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()];
|
||||
|
||||
// Call block-level pipeline to compute C tile
|
||||
const auto c_block_tile =
|
||||
block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char);
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
1. **Retrieve block pipeline**: The policy provides the block-level GEMM implementation
|
||||
2. **Calculate K iterations**: How many times to iterate over the K dimension
|
||||
- In our example: `K=64, KPerBlock=32` → **2 iterations**
|
||||
- Each iteration processes 32 elements of the K dimension
|
||||
- Results are accumulated across iterations
|
||||
|
||||
3. **Allocate shared memory**:
|
||||
- `__shared__` declares memory shared by all threads in the block
|
||||
- `GetStaticLDSSize()` returns the required size in bytes
|
||||
- This memory is used for:
|
||||
- Staging data from DRAM → LDS
|
||||
- Cooperative loading by threads
|
||||
- Fast access during computation
|
||||
|
||||
4. **Execute block pipeline**:
|
||||
- Takes A and B tile windows as input
|
||||
- Performs the GEMM computation: `C_tile = A_tile × B_tile`
|
||||
- Returns result in `c_block_tile` (stored in VGPRs - registers)
|
||||
|
||||
### Memory Hierarchy During Computation:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ DRAM (Global Memory) - Slowest, Largest │
|
||||
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
|
||||
│ │ A matrix │ │ B matrix │ │ C matrix │ │
|
||||
│ └─────────────┘ └─────────────┘ └─────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
↓ load ↓ load ↑ store
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ LDS (Shared Memory) - Fast, Limited Size (~64KB) │
|
||||
│ ┌─────────────┐ ┌─────────────┐ │
|
||||
│ │ A_tile │ │ B_tile │ ← Staged here │
|
||||
│ │ (p_smem) │ │ (p_smem) │ │
|
||||
│ └─────────────┘ └─────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
↓ load ↓ load
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ VGPRs (Registers) - Fastest, Smallest (~256 regs/thread) │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ c_block_tile (accumulated result) │ │
|
||||
│ │ Computation happens here using MFMA instructions │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Block Pipeline Responsibilities:
|
||||
|
||||
The block pipeline (called here) will:
|
||||
1. Load A and B tiles from DRAM → LDS (cooperative loading)
|
||||
2. Distribute work among warps
|
||||
3. Each warp loads its portion from LDS → VGPRs
|
||||
4. Perform MFMA operations: `C += A × B`
|
||||
5. Accumulate results in VGPRs
|
||||
6. Return final `c_block_tile` in registers
|
||||
|
||||
---
|
||||
|
||||
## Step 5: Store Results to DRAM
|
||||
|
||||
```cpp
|
||||
// Create a tile window over C for writing results
|
||||
auto c_window = make_tile_window(
|
||||
c_dram, // Tensor view over C
|
||||
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}), // Window size: 256×128
|
||||
{tile_origin_m, tile_origin_n} // Origin: varies by block
|
||||
);
|
||||
|
||||
// Store computed tile from VGPRs to DRAM
|
||||
store_tile(c_window, c_block_tile);
|
||||
```
|
||||
|
||||
### C Window Origins for Each Block:
|
||||
|
||||
```cpp
|
||||
// Block 0: Writes to top-left tile
|
||||
c_window origin: (0, 0) → writes to C[0:255, 0:127]
|
||||
|
||||
// Block 1: Writes to bottom-left tile
|
||||
c_window origin: (256, 0) → writes to C[256:511, 0:127]
|
||||
|
||||
// Block 2: Writes to top-right tile
|
||||
c_window origin: (0, 128) → writes to C[0:255, 128:255]
|
||||
|
||||
// Block 3: Writes to bottom-right tile
|
||||
c_window origin: (256, 128) → writes to C[256:511, 128:255]
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
1. **Create C tile window**:
|
||||
- Size: 256×128 (matches our block tile size)
|
||||
- Origin: Varies by block - each block writes to its assigned region
|
||||
- This window defines **where** to write the results
|
||||
|
||||
2. **Store tile to DRAM**:
|
||||
- `c_block_tile`: Computed results in VGPRs (registers)
|
||||
- `c_window`: Destination window in DRAM
|
||||
- `store_tile()`: Efficiently writes data from registers → DRAM
|
||||
|
||||
### The `store_tile` Function:
|
||||
|
||||
Recall from our earlier discussion, `store_tile` does:
|
||||
|
||||
```cpp
|
||||
template <typename TileWindow, typename DistributedTensor>
|
||||
void store_tile(TileWindow& tile_window_tmp,
|
||||
const DistributedTensor& dstr_tensor)
|
||||
{
|
||||
// 1. Extract tile distribution from distributed tensor
|
||||
using TileDstr = typename DistributedTensor::TileDistribution;
|
||||
|
||||
// 2. Upgrade simple tile window to one with distribution
|
||||
auto tile_window = make_tile_window(
|
||||
tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
TileDstr{} // Add distribution info
|
||||
);
|
||||
|
||||
// 3. Store using vectorized writes
|
||||
tile_window.store(dstr_tensor);
|
||||
}
|
||||
```
|
||||
|
||||
### Memory Flow:
|
||||
|
||||
```
|
||||
VGPRs (Registers) DRAM (Global Memory)
|
||||
┌─────────────────────┐ ┌─────────────────────┐
|
||||
│ c_block_tile │ │ C matrix │
|
||||
│ ┌───┬───┬───┬───┐ │ │ ┌───────────────┐ │
|
||||
│ │W0 │W1 │W2 │W3 │ │ store_tile │ │ │ │
|
||||
│ ├───┼───┼───┼───┤ │ ==========> │ │ c_window │ │
|
||||
│ │...│...│...│...│ │ vectorized │ │ (256×128) │ │
|
||||
│ └───┴───┴───┴───┘ │ │ │ │ │
|
||||
│ Distributed across │ │ └───────────────┘ │
|
||||
│ threads/warps │ │ Origin: (0, 0) │
|
||||
└─────────────────────┘ └─────────────────────┘
|
||||
|
||||
Each thread writes its portion using vector stores (e.g., float4)
|
||||
```
|
||||
|
||||
### Store Optimization:
|
||||
|
||||
The `store_tile` function:
|
||||
- Uses **vectorized stores** (write multiple elements at once)
|
||||
- Ensures **coalesced memory access** (adjacent threads write adjacent memory)
|
||||
- Respects **tile distribution** (each thread knows what data it owns)
|
||||
- Handles **out-of-bounds** checking (for partial tiles at boundaries)
|
||||
|
||||
---
|
||||
|
||||
## Complete Flow Visualization
|
||||
|
||||
Let's trace the complete flow for **Block 0** (other blocks follow the same pattern):
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 1: Calculate Tile Coverage │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ M=512, N=256, K=64 │ │
|
||||
│ │ MPerBlock=256, NPerBlock=128, KPerBlock=32 │ │
|
||||
│ │ num_tile_m = ceil(512/256) = 2 │ │
|
||||
│ │ num_tile_n = ceil(256/128) = 2 │ │
|
||||
│ │ Total blocks needed = 2 × 2 = 4 blocks │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 2: Map Block to Tile (Block 0 example) │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Block ID: 0 │ │
|
||||
│ │ Tile coordinates: (0, 0) - top-left tile │ │
|
||||
│ │ Tile origin: (0, 0) │ │
|
||||
│ │ │ │
|
||||
│ │ (Blocks 1,2,3 get different tile coordinates) │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 3: Create Tile Windows │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ a_block_window: 256×32 starting at (0,0) over A │ │
|
||||
│ │ b_block_window: 128×32 starting at (0,0) over B │ │
|
||||
│ │ Windows initially cover K columns 0-31 │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 4: Execute Block Pipeline (2 K iterations) │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Allocate shared memory (LDS) │ │
|
||||
│ │ Call block_gemm_pipeline(a_window, b_window, 2, p_smem) │ │
|
||||
│ │ │ │
|
||||
│ │ K Iteration 0 (K=0-31): │ │
|
||||
│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] │ │
|
||||
│ │ └─ Move windows: {0, 32} │ │
|
||||
│ │ │ │
|
||||
│ │ K Iteration 1 (K=32-63): │ │
|
||||
│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] │ │
|
||||
│ │ └─ Accumulate: C_tile = C_partial_0 + C_partial_1 │ │
|
||||
│ │ │ │
|
||||
│ │ Return c_block_tile in VGPRs (256×128 accumulated result) │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 5: Store Results │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Create c_window: 256×128 starting at (0,0) over C │ │
|
||||
│ │ store_tile(c_window, c_block_tile) │ │
|
||||
│ │ └─ Write from VGPRs → DRAM (vectorized stores) │ │
|
||||
│ │ │ │
|
||||
│ │ Block 0 writes to C[0:255, 0:127] │ │
|
||||
│ │ (Other blocks write to their respective regions) │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
|
||||
All 4 blocks execute in parallel, each computing its assigned 256×128 tile!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Concepts Summary
|
||||
|
||||
### 1. **Tile Coverage**
|
||||
- Determines how many thread blocks are needed
|
||||
- Each block processes one tile of the output matrix C
|
||||
- Calculated as `ceil(dimension / tile_size)`
|
||||
|
||||
### 2. **Block-to-Tile Mapping**
|
||||
- Maps linear block ID to 2D tile coordinates
|
||||
- Uses column-major ordering for better memory coalescing
|
||||
- Each block knows which tile it's responsible for
|
||||
|
||||
### 3. **Tile Windows**
|
||||
- **Sliding windows** over larger tensor views
|
||||
- Define a rectangular region with fixed size and movable origin
|
||||
- Provide efficient, structured access to tensor data
|
||||
- Can be moved to access different regions (e.g., for K iterations)
|
||||
|
||||
### 4. **Memory Hierarchy**
|
||||
- **DRAM (Global)**: Largest, slowest - stores full matrices
|
||||
- **LDS (Shared)**: Medium, fast - stages tiles for cooperative access
|
||||
- **VGPRs (Registers)**: Smallest, fastest - performs computation
|
||||
|
||||
### 5. **Data Flow**
|
||||
```
|
||||
DRAM → Tile Windows → LDS → VGPRs → Computation → VGPRs → DRAM
|
||||
↑ ↓
|
||||
A, B matrices C matrix
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Next Steps
|
||||
|
||||
The host-level pipeline has set up the work and delegated to the block-level pipeline. Next, we'll explore:
|
||||
- **Block-level pipeline**: How tiles are loaded, distributed to warps, and computed
|
||||
- **Warp-level pipeline**: How warps perform MFMA operations
|
||||
- **Memory optimization**: LDS usage, bank conflicts, coalescing
|
||||
|
||||
The host level provides the **orchestration**, while the block and warp levels provide the **execution**!
|
||||
|
||||
@@ -1,464 +0,0 @@
|
||||
# PracticeGemmKernel: Understanding the Kernel Entry Point
|
||||
|
||||
This document explains the `PracticeGemmKernel` structure, which serves as the **entry point** for our GEMM GPU kernel. We'll dive deep into how raw memory is transformed into structured tensor views.
|
||||
|
||||
## Overview
|
||||
|
||||
The `PracticeGemmKernel` is a templated struct that:
|
||||
1. Takes raw device memory pointers for matrices A, B, and C
|
||||
2. Wraps them into **tensor views** - logical, structured views over physical memory
|
||||
3. Dispatches to the host-level pipeline for computation
|
||||
|
||||
```cpp
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct PracticeGemmKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
|
||||
const typename Problem::BDataType* p_b,
|
||||
typename Problem::CDataType* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t stride_a,
|
||||
const index_t stride_b,
|
||||
const index_t stride_c) const
|
||||
{
|
||||
// Step 1: Create tensor views over raw memory
|
||||
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{});
|
||||
|
||||
auto b_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{});
|
||||
|
||||
const auto c_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{});
|
||||
|
||||
// Step 2: Dispatch to host-level pipeline
|
||||
PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## What are Tensor Views?
|
||||
|
||||
A **tensor view** is a **logical, structured view over raw physical memory**. It doesn't own or allocate memory—it simply provides a way to interpret and access existing memory as a multi-dimensional tensor.
|
||||
|
||||
### Key Components of a Tensor View:
|
||||
|
||||
1. **Memory Type**: Where the data lives (global/DRAM, LDS/shared, registers)
|
||||
2. **Raw Pointer**: Points to the actual data in memory
|
||||
3. **Shape**: Dimensions of the tensor (e.g., M×K for matrix A)
|
||||
4. **Strides**: How to navigate through memory to access elements
|
||||
5. **Guaranteed Vector Length**: How many consecutive elements can be loaded in one vector instruction
|
||||
6. **Guaranteed Vector Stride**: The stride of those vectorizable elements
|
||||
|
||||
---
|
||||
|
||||
## The Memory Abstraction Hierarchy
|
||||
|
||||
CK Tile uses a three-layer abstraction to go from raw memory to structured tensors:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Layer 3: TENSOR VIEW │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ • Logical multi-dimensional structure │ │
|
||||
│ │ • Shape: (M, K) = (256, 32) │ │
|
||||
│ │ • Strides: (32, 1) for row-major layout │ │
|
||||
│ │ • Provides: operator[], coordinate-based access │ │
|
||||
│ │ • Knows: How to map (i,j) → linear offset │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
│ ↓ wraps │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ Layer 2: BUFFER VIEW │ │
|
||||
│ │ ┌─────────────────────────────────────────────────────┐ │ │
|
||||
│ │ │ • Linear view of memory │ │ │
|
||||
│ │ │ • Pointer: p_data_ → device memory │ │ │
|
||||
│ │ │ • Size: Total number of elements │ │ │
|
||||
│ │ │ • Address space: global/LDS/generic │ │ │
|
||||
│ │ │ • Provides: Vectorized loads/stores, bounds checking│ │ │
|
||||
│ │ └─────────────────────────────────────────────────────┘ │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
│ ↓ wraps │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ Layer 1: RAW PHYSICAL MEMORY │ │
|
||||
│ │ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ │
|
||||
│ │ │ 0.0 │ 1.0 │ 2.0 │ 3.0 │ 4.0 │ 5.0 │ 6.0 │ 7.0 │ ... │ │ │
|
||||
│ │ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ │ │
|
||||
│ │ ↑ │ │
|
||||
│ │ p_a (raw pointer from hipMalloc) │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Deep Dive: `make_naive_tensor_view`
|
||||
|
||||
Let's break down the function call for matrix A:
|
||||
|
||||
```cpp
|
||||
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, // Raw pointer to device memory
|
||||
make_tuple(M, K), // Shape: (256, 32)
|
||||
make_tuple(stride_a, 1), // Strides: (32, 1) - row-major
|
||||
number<8>{}, // Guaranteed vector length
|
||||
number<1>{} // Guaranteed vector stride
|
||||
);
|
||||
```
|
||||
|
||||
### Function Signature:
|
||||
|
||||
```cpp
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view(DataType* __restrict__ p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
// Step 1: Create tensor descriptor (shape + stride information)
|
||||
auto desc = make_naive_tensor_descriptor(lengths,
|
||||
strides,
|
||||
number<GuaranteedLastDimensionVectorLength>{},
|
||||
number<GuaranteedLastDimensionVectorStride>{});
|
||||
|
||||
// Step 2: Create buffer view (pointer + size + address space)
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
// Step 3: Combine into tensor view
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Parameter Breakdown
|
||||
|
||||
### 1. **Template Parameter: `address_space_enum::global`**
|
||||
|
||||
Specifies where the memory lives:
|
||||
- `global`: GPU global memory (DRAM) - slowest but largest
|
||||
- `lds`: Local Data Share (shared memory) - fast, limited size
|
||||
- `generic`: Generic address space
|
||||
- `vgpr`: Vector General Purpose Registers - fastest, smallest
|
||||
|
||||
In our case, `global` means the data is in GPU DRAM.
|
||||
|
||||
### 2. **`p_a` - Raw Pointer**
|
||||
|
||||
The raw device memory pointer returned by `hipMalloc`. Points to the start of the matrix data.
|
||||
|
||||
### 3. **`make_tuple(M, K)` - Shape/Lengths**
|
||||
|
||||
Defines the logical dimensions of the tensor:
|
||||
- For matrix A: `(256, 32)` means 256 rows, 32 columns
|
||||
- This is the **logical view**, independent of how data is physically laid out
|
||||
|
||||
### 4. **`make_tuple(stride_a, 1)` - Strides**
|
||||
|
||||
Defines how to navigate through memory:
|
||||
- **Stride for dimension 0 (rows)**: `stride_a = K = 32`
|
||||
- To move to the next row, skip 32 elements
|
||||
- **Stride for dimension 1 (columns)**: `1`
|
||||
- To move to the next column, skip 1 element
|
||||
|
||||
**Row-major layout example:**
|
||||
```
|
||||
Memory: [a₀₀, a₀₁, a₀₂, ..., a₀₃₁, a₁₀, a₁₁, a₁₂, ..., a₁₃₁, ...]
|
||||
↑ ↑
|
||||
Row 0 starts here Row 1 starts here (offset = 32)
|
||||
|
||||
To access element A[i][j]:
|
||||
offset = i * stride_a + j * 1
|
||||
= i * 32 + j
|
||||
```
|
||||
|
||||
### 5. **`number<8>{}` - Guaranteed Last Dimension Vector Length**
|
||||
|
||||
This tells the tensor view: **"The last dimension (K) is guaranteed to have at least 8 consecutive elements that can be loaded together in a single vector instruction."**
|
||||
|
||||
#### Why is this important?
|
||||
|
||||
Modern GPUs can load multiple elements in one instruction (vectorized loads):
|
||||
- `float4`: Load 4 floats at once
|
||||
- `float8`: Load 8 floats at once (if supported)
|
||||
|
||||
By specifying `number<8>{}`, we're telling the system:
|
||||
- "You can safely use vector loads of up to 8 elements"
|
||||
- "The memory alignment and layout support this"
|
||||
|
||||
**Example:**
|
||||
```cpp
|
||||
// Without vectorization (slow):
|
||||
for (int j = 0; j < 8; j++) {
|
||||
data[j] = memory[offset + j]; // 8 separate loads
|
||||
}
|
||||
|
||||
// With vectorization (fast):
|
||||
float8 vec = *reinterpret_cast<float8*>(&memory[offset]); // 1 load!
|
||||
```
|
||||
|
||||
### 6. **`number<1>{}` - Guaranteed Last Dimension Vector Stride**
|
||||
|
||||
This specifies the **stride between consecutive vectorizable elements** in the last dimension.
|
||||
|
||||
- `number<1>{}` means: "Consecutive elements in the last dimension are contiguous in memory (stride = 1)"
|
||||
- This confirms that elements `A[i][0], A[i][1], A[i][2], ..., A[i][7]` are stored consecutively
|
||||
|
||||
**Why does this matter?**
|
||||
|
||||
For efficient vectorized loads, elements must be:
|
||||
1. **Contiguous** (stride = 1) ✓
|
||||
2. **Aligned** properly in memory
|
||||
3. **Within the same cache line** (ideally)
|
||||
|
||||
If the stride were `2`, it would mean:
|
||||
```
|
||||
A[i][0] is at offset 0
|
||||
A[i][1] is at offset 2 (not 1!)
|
||||
A[i][2] is at offset 4
|
||||
```
|
||||
This would prevent efficient vectorization.
|
||||
|
||||
---
|
||||
|
||||
## What is a Buffer View?
|
||||
|
||||
A **buffer view** is the middle layer between raw memory and tensor view. It provides:
|
||||
|
||||
### Core Responsibilities:
|
||||
|
||||
1. **Memory Management**
|
||||
- Holds the raw pointer: `T* p_data_`
|
||||
- Tracks buffer size: `BufferSizeType buffer_size_`
|
||||
- Knows the address space: `global`, `lds`, etc.
|
||||
|
||||
2. **Vectorized Access**
|
||||
```cpp
|
||||
template <typename VectorType>
|
||||
CK_TILE_DEVICE VectorType get(index_t offset);
|
||||
```
|
||||
- Provides efficient vector loads/stores
|
||||
- Handles alignment requirements
|
||||
|
||||
3. **Bounds Checking** (optional)
|
||||
```cpp
|
||||
template <bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto get(index_t i, index_t linear_offset);
|
||||
```
|
||||
- Can optionally check if access is within bounds
|
||||
- Returns invalid value (default 0) for out-of-bounds access
|
||||
|
||||
4. **Address Space Awareness**
|
||||
- Uses different load/store instructions based on address space
|
||||
- Global memory: `global_load`, `global_store`
|
||||
- LDS: `ds_read`, `ds_write`
|
||||
|
||||
### Buffer View Structure:
|
||||
|
||||
```cpp
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
typename T,
|
||||
typename BufferSizeType,
|
||||
bool InvalidElementUseNumericalZeroValue,
|
||||
amd_buffer_coherence_enum Coherence>
|
||||
struct buffer_view
|
||||
{
|
||||
T* p_data_; // Raw pointer
|
||||
BufferSizeType buffer_size_; // Total elements
|
||||
remove_cvref_t<T> invalid_element_value_; // Value for OOB access
|
||||
|
||||
// Access operators
|
||||
const T& operator[](index_t i) const; // Read
|
||||
T& operator()(index_t i); // Write
|
||||
|
||||
// Vectorized access
|
||||
template <typename VectorType>
|
||||
VectorType get(index_t offset);
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Visual Example: Matrix A Memory Layout
|
||||
|
||||
Let's visualize how matrix A (256×32, fp16) is organized:
|
||||
|
||||
### Raw Physical Memory (Linear):
|
||||
```
|
||||
GPU DRAM Address Space:
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Byte 0 │
|
||||
│ ↓ │
|
||||
│ [a₀₀][a₀₁][a₀₂]...[a₀₃₁][a₁₀][a₁₁][a₁₂]...[a₁₃₁][a₂₀]... │
|
||||
│ ↑ ↑ │
|
||||
│ Row 0 (32 elements) Row 1 (32 elements) │
|
||||
│ │
|
||||
│ Total: 256 rows × 32 cols × 2 bytes/element = 16,384 bytes │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↑
|
||||
p_a (raw pointer)
|
||||
```
|
||||
|
||||
### Buffer View Layer:
|
||||
```
|
||||
buffer_view<address_space_enum::global, fp16_t, ...>
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ p_data_ = p_a │
|
||||
│ buffer_size_ = 256 × 32 = 8,192 elements │
|
||||
│ address_space = global (DRAM) │
|
||||
│ │
|
||||
│ Provides: │
|
||||
│ • Linear indexing: buffer_view[i] → element at offset i │
|
||||
│ • Vectorized loads: get<float4>(offset) → load 4 fp16s at once│
|
||||
│ • Bounds checking: is offset < buffer_size_? │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Tensor View Layer:
|
||||
```
|
||||
tensor_view<buffer_view, tensor_descriptor>
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Shape: (256, 32) │
|
||||
│ Strides: (32, 1) │
|
||||
│ Guaranteed vector length: 8 │
|
||||
│ Guaranteed vector stride: 1 │
|
||||
│ │
|
||||
│ Logical 2D View: │
|
||||
│ Col: 0 1 2 ... 31 │
|
||||
│ Row 0: [a₀₀][a₀₁][a₀₂] ... [a₀₃₁] ← Can vector load 8 at once│
|
||||
│ Row 1: [a₁₀][a₁₁][a₁₂] ... [a₁₃₁] │
|
||||
│ Row 2: [a₂₀][a₂₁][a₂₂] ... [a₂₃₁] │
|
||||
│ ... │
|
||||
│ Row 255: [a₂₅₅,₀] ... [a₂₅₅,₃₁] │
|
||||
│ │
|
||||
│ Provides: │
|
||||
│ • Multi-dimensional indexing: A[i][j] │
|
||||
│ • Coordinate transformation: (i,j) → linear offset = i*32 + j │
|
||||
│ • Tile window creation: Extract sub-tensors │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Complete Flow: Raw Memory → Tensor View
|
||||
|
||||
Let's trace the complete transformation for matrix A:
|
||||
|
||||
### Step 1: Kernel Launch (Host Side)
|
||||
```cpp
|
||||
// On host: Allocate device memory
|
||||
hipMalloc(&p_a, M * K * sizeof(fp16_t)); // Returns raw pointer
|
||||
|
||||
// Launch kernel
|
||||
kernel<<<grid, block>>>(p_a, p_b, p_c, M, N, K, ...);
|
||||
```
|
||||
|
||||
### Step 2: Inside Kernel (Device Side)
|
||||
```cpp
|
||||
// Receive raw pointer
|
||||
const fp16_t* p_a; // Points to GPU DRAM
|
||||
|
||||
// Step 2a: Create tensor descriptor
|
||||
auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(256, 32), // Shape
|
||||
make_tuple(32, 1), // Strides
|
||||
number<8>{}, // Vector length
|
||||
number<1>{} // Vector stride
|
||||
);
|
||||
// desc now knows: "This is a 256×32 tensor, row-major, vectorizable by 8"
|
||||
|
||||
// Step 2b: Create buffer view
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_a, // Raw pointer
|
||||
256 * 32 // Total elements
|
||||
);
|
||||
// buffer_view now wraps p_a with size and address space info
|
||||
|
||||
// Step 2c: Create tensor view
|
||||
auto a_dram = tensor_view{buffer_view, desc};
|
||||
// a_dram now provides structured, multi-dimensional access to p_a
|
||||
```
|
||||
|
||||
### Step 3: Using the Tensor View
|
||||
```cpp
|
||||
// Access element A[i][j]
|
||||
auto value = a_dram[make_tuple(i, j)];
|
||||
|
||||
// Create a tile window (sub-tensor)
|
||||
auto tile = make_tile_window(
|
||||
a_dram,
|
||||
make_tuple(16, 16), // 16×16 tile
|
||||
make_tuple(0, 0) // Starting at origin
|
||||
);
|
||||
|
||||
// Load tile into registers with vectorization
|
||||
auto tile_data = load_tile(tile); // Uses vector loads internally!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Why This Abstraction?
|
||||
|
||||
### Benefits:
|
||||
|
||||
1. **Type Safety**: Can't accidentally access wrong dimensions
|
||||
2. **Performance**: Compiler knows about vectorization opportunities
|
||||
3. **Flexibility**: Same code works for different memory spaces (DRAM, LDS, registers)
|
||||
4. **Maintainability**: Logical structure separate from physical layout
|
||||
5. **Optimization**: Guaranteed vector properties enable aggressive optimizations
|
||||
|
||||
### Example: Without Tensor Views (Manual Indexing)
|
||||
```cpp
|
||||
// Ugly, error-prone, hard to optimize:
|
||||
for (int i = 0; i < 16; i++) {
|
||||
for (int j = 0; j < 16; j++) {
|
||||
float val = p_a[tile_offset_i * stride_a + tile_offset_j + i * stride_a + j];
|
||||
// Hope the compiler vectorizes this? 🤞
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example: With Tensor Views (Clean, Optimized)
|
||||
```cpp
|
||||
// Clean, safe, automatically vectorized:
|
||||
auto tile = make_tile_window(a_dram, make_tuple(16, 16), origin);
|
||||
auto tile_data = load_tile(tile); // Vectorized loads guaranteed!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
The `PracticeGemmKernel` entry point transforms raw GPU memory into structured, multi-dimensional tensors through a three-layer abstraction:
|
||||
|
||||
1. **Raw Memory**: Linear array of bytes in GPU DRAM
|
||||
2. **Buffer View**: Adds size, address space, and vectorized access
|
||||
3. **Tensor View**: Adds shape, strides, and multi-dimensional indexing
|
||||
|
||||
This abstraction enables:
|
||||
- ✅ Clean, readable code
|
||||
- ✅ Type-safe multi-dimensional access
|
||||
- ✅ Automatic vectorization
|
||||
- ✅ Flexible memory space handling
|
||||
- ✅ Efficient tile-based computation
|
||||
|
||||
The tensor views created here are then passed to the host-level pipeline, which orchestrates the block-level GEMM computation!
|
||||
|
||||
@@ -1,150 +1,115 @@
|
||||
# CK Tile Practice GEMM Example
|
||||
# CK Tile Naive GEMM Tutorial
|
||||
|
||||
This is a practice implementation of a GEMM (General Matrix Multiplication) kernel using the CK Tile API. It demonstrates the fundamental concepts of GPU kernel development using CK Tile's hierarchical tile system.
|
||||
A minimal GEMM (`C = A × B`) using the CK Tile API. No optimizations — just the
|
||||
core data flow through the three-level hierarchy: host → block → warp.
|
||||
|
||||
## CK Tile API Structure
|
||||
## Key Terms
|
||||
|
||||
In the composable_kernel library's ck_tile API, **A Kernel is composed of a Problem, a Policy and an Epilogue**:
|
||||
| Term | What it is |
|
||||
|------|-----------|
|
||||
| **Problem** | Shape, data types, and layout of the GEMM matrices |
|
||||
| **Policy** | How data and computation are mapped to threads (tile distributions, warp configs) |
|
||||
| **Pipeline** | The loop that moves data through DRAM → VGPRs → LDS → MFMA and accumulates C |
|
||||
| **Epilogue** | Post-GEMM work (e.g. activation, scaling). Not used in this tutorial |
|
||||
|
||||
1. **Problem** describes the shape, data type, data layout, precision of our GEMM matrices
|
||||
2. **Policy** describes how the data in the matrix (or tile) is mapped to the threads
|
||||
3. **Epilogue** describes additional computation work performed after the gemm computations (this example does not have an epilogue)
|
||||
## Execution Hierarchy
|
||||
|
||||
## Overview
|
||||
|
||||
This example implements a complete GEMM kernel `C = A × B` using the CK Tile framework, showcasing:
|
||||
|
||||
- **Problem Setup** - Setting up the problem (input/output shapes, data types, mathematical operations), composing a kernel (pipeline, policy, epilogue), kernel launch
|
||||
- **Block-level Pipelining** - creating tensor views, dispatching to block-level GEMM
|
||||
- **Block-level GEMM Computation** - Block tiles, tile window creation, loading/storing to DRAM and Register memory
|
||||
- **Warp-level GEMM Computation** - Warp tiles, MFMA level computation
|
||||
|
||||
## Problem Setup and Data Flow
|
||||
|
||||
### Problem Size Configuration
|
||||
We set the problem size using the M, N and K variables:
|
||||
```cpp
|
||||
ck_tile::index_t M = 1024; // Number of rows in A and C
|
||||
ck_tile::index_t N = 512; // Number of columns in B and C
|
||||
ck_tile::index_t K = 256; // Number of columns in A, rows in B
|
||||
```
|
||||
practice_gemm.cpp ← host: parse args, allocate, launch, verify
|
||||
└─ grid_gemm.hpp ← host-level: block-to-tile mapping, create tile windows
|
||||
└─ block_gemm_pipeline_agmem_bgmem_creg.hpp
|
||||
│ ← block-level: loop over K, DRAM→VGPR→LDS, call warp GEMM
|
||||
└─ block_gemm_asmem_bsmem_creg.hpp
|
||||
← warp-level: LDS→VGPR, MFMA m32n32k8, accumulate C
|
||||
```
|
||||
|
||||
### Host Matrix Creation
|
||||
Three host matrices A (M×K), B (N×K) and C (M×N) are created, initialized on the CPU and copied over to the GPU global/DRAM memory:
|
||||
```cpp
|
||||
// Host tensors with proper strides
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides); // M × K
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides); // N × K
|
||||
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides); // M × N
|
||||
|
||||
// Initialize with random data
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
|
||||
|
||||
// Allocate device memory and transfer data
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
a_device.ToDevice(a_host.data());
|
||||
**Data flow per K-iteration:**
|
||||
```
|
||||
A,B in DRAM ──load_tile──► VGPRs ──store_tile──► LDS ──sync──► warp GEMM (MFMA) ──► C in VGPRs
|
||||
```
|
||||
|
||||
### PracticeGemmShape Configuration
|
||||
A PracticeGemmShape struct holds the dimension of each BlockTile and WaveTile:
|
||||
After all K-iterations, C is stored back to DRAM.
|
||||
|
||||
```cpp
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>; // M, N, K per block
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>; // M, N, K per wave
|
||||
```
|
||||
- A BlockTile of size MxK (256x32) on A matrix and NxK (128x32) on B matrix. A WaveTile of size MxN (16x16) on C matrix.
|
||||
## File Guide
|
||||
|
||||
| File | Role |
|
||||
|------|------|
|
||||
| `practice_gemm.cpp` | Entry point: sizes, host tensors, kernel launch, verification |
|
||||
| `practice_gemm.hpp` | Composes `GridGemmProblem`, `BlockGemmPipelineProblem`, and `Gemm` struct |
|
||||
| `reference_gemm.hpp` | CPU reference for correctness checking |
|
||||
| `grid_gemm.hpp` | Host-level pipeline: maps `blockIdx` to tile coordinates, creates A/B/C tile windows |
|
||||
| `block_gemm_pipeline_agmem_bgmem_creg.hpp` | Block-level pipeline: K-loop, DRAM→LDS data movement |
|
||||
| `block_gemm_pipeline_agmem_bgmem_creg_policy.hpp` | Policy: A/B DRAM tile distributions, LDS descriptors |
|
||||
| `block_gemm_asmem_bsmem_creg.hpp` | Warp-level: reads A/B from LDS, runs MFMA, accumulates C |
|
||||
| `block_gemm_asmem_bsmem_creg_policy.hpp` | Policy: WarpGemm type selection (standard vs transposed C) |
|
||||
|
||||
- BlockTiles iterate in K dimension to fetch data required for computing region of C covered by C's block tile.
|
||||
- BlockTiles are further subdivided into WarpTiles.
|
||||
- WarpTiles over A and B similarly work together to calculate the WarpTile of C.
|
||||
## Tile Sizes
|
||||
|
||||
### Problem and Policy Composition
|
||||
```cpp
|
||||
// A Problem is composed from Shape and info about the data
|
||||
using PracticeGemmHostProblem = ck_tile::
|
||||
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
|
||||
From `practice_gemm.cpp` (fp16, `BlockSize=256`):
|
||||
|
||||
// A Policy is created describing data-to-thread mapping
|
||||
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
|
||||
| Matrix | Block tile | Description |
|
||||
|--------|-----------|-------------|
|
||||
| A | 256 × 32 | M × K per block |
|
||||
| B | 128 × 32 | N × K per block |
|
||||
| C | 256 × 128 | M × N per block (accumulated in registers) |
|
||||
|
||||
// A Kernel is then composed of Problem and Policy
|
||||
using gemm_kernel = ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
|
||||
Each block tile is further split across 4 warps (MWarp=4, NWarp=1).
|
||||
The warp-level MFMA instruction is `m32n32k8`.
|
||||
|
||||
## Tile Distributions
|
||||
|
||||
The policy files define how threads map to tile elements:
|
||||
|
||||
**A and B DRAM loads** (`block_gemm_pipeline_agmem_bgmem_creg_policy.hpp`):
|
||||
- Factor M (or N) into `M0 × M1 × M2`, K into `K0 × K1`
|
||||
- `P0 = warp_id → M1`, `P1 = lane_id → M2 × K0` (merged for coalescing)
|
||||
- `Y0 = M0` (iterations), `Y1 = K1` (vector load width = 8 for fp16)
|
||||
- See the `tile_distribution/` tutorials for worked examples with these exact shapes
|
||||
|
||||
**C register layout** (`block_gemm_asmem_bsmem_creg_policy.hpp`):
|
||||
- Determined by the WarpGemm type (MFMA instruction output mapping)
|
||||
- Standard: M-dimension in Hs[0], N-dimension in Hs[1]
|
||||
- Transposed: swaps M/N dimensions, changes which lanes hold which C elements
|
||||
|
||||
## Transposed C Distribution Switch
|
||||
|
||||
The macro `CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION` (default: 1) selects between
|
||||
two WarpGemm variants:
|
||||
|
||||
| Value | WarpGemm | C layout |
|
||||
|-------|----------|----------|
|
||||
| 1 (default) | `WarpGemmMfma*TransposedCDistribution` | Swapped A/B in MFMA, transposed C register layout |
|
||||
| 0 | `WarpGemmMfma*` | Standard MFMA, standard C register layout |
|
||||
|
||||
To build with the standard (non-transposed) variant, pass the define via compiler flags:
|
||||
```bash
|
||||
cmake -DCMAKE_CXX_FLAGS="-DCK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0" ..
|
||||
```
|
||||
|
||||
### Kernel Launch
|
||||
`ck_tile::launch_kernel()` is used to launch the kernel on device. It calls the `operator()` function of `PracticeGemmKernel{}`:
|
||||
```cpp
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 0, 1},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCU>(
|
||||
gemm_kernel{}, // Kernel composed of Problem + Policy
|
||||
kGridSize, // Grid dimensions
|
||||
kBlockSize, // Block dimensions
|
||||
0, // Dynamic shared memory
|
||||
// Kernel arguments: device buffers and problem dimensions
|
||||
a_device.GetDeviceBuffer(), b_device.GetDeviceBuffer(), c_device.GetDeviceBuffer(),
|
||||
M, N, K, stride_a, stride_b, stride_c));
|
||||
```
|
||||
|
||||
### Result Verification
|
||||
The results from the kernel are compared with results from CPU based computation function:
|
||||
```cpp
|
||||
// CPU reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(a_host, b_host, c_host_ref);
|
||||
|
||||
// Device results
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
|
||||
// Verify correctness
|
||||
bool pass = ck_tile::check_err(c_host_dev, c_host_ref);
|
||||
```
|
||||
|
||||
### Runtime Flow
|
||||
|
||||
The main program (`practice_gemm.cpp`) is the entry point for the runtime flow:
|
||||
|
||||
```cpp
|
||||
int main()
|
||||
{
|
||||
// 1. Define data types and problem sizes
|
||||
using ADataType = ck_tile::half_t;
|
||||
ck_tile::index_t M = 2048, N = 1024, K = 512;
|
||||
|
||||
// 2. Create host tensors and initialize
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
|
||||
// 3. Allocate device memory and transfer data
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
|
||||
// 4. Configure tile shapes
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>;
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// 5. Launch kernel
|
||||
using gemm_kernel = ck_tile::PracticeGemmKernel<Problem, Policy>;
|
||||
float ave_time = ck_tile::launch_kernel(/*...*/);
|
||||
|
||||
// 6. Verify results
|
||||
bool pass = verify_results(a_host, b_host, c_host);
|
||||
|
||||
// 7. Print performance metrics
|
||||
print_performance_metrics(ave_time, M, N, K);
|
||||
}
|
||||
```
|
||||
Both variants produce correct results — they differ only in how C elements are
|
||||
distributed across thread registers, which affects downstream store coalescing.
|
||||
|
||||
## Building and Running
|
||||
|
||||
```bash
|
||||
# From composable_kernel root directory
|
||||
mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_tutorial_naive_gemm -j
|
||||
cd <repo-root>/projects/composablekernel/build
|
||||
|
||||
# Run with sample sizes
|
||||
# Configure (first time)
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
|
||||
# Build
|
||||
make tile_tutorial_naive_gemm -j
|
||||
# or: ninja tile_tutorial_naive_gemm
|
||||
|
||||
# Run (default: M=3328, N=4096, K=4096)
|
||||
./bin/tile_tutorial_naive_gemm
|
||||
|
||||
# Custom sizes (positional args: verification M N K)
|
||||
./bin/tile_tutorial_naive_gemm 0 1024 512 256
|
||||
```
|
||||
This example serves as a foundation for understanding more complex GEMM implementations and optimization strategies in the CK Tile framework.
|
||||
|
||||
## Reference
|
||||
|
||||
- Tile distribution encoding: `include/ck_tile/core/tensor/tile_distribution_encoding.hpp`
|
||||
- MFMA warp gemm: `include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp`
|
||||
- Production GEMM pipeline: `include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp`
|
||||
|
||||
@@ -1,312 +0,0 @@
|
||||
# Tile Distribution: Mapping Threads to Data
|
||||
|
||||
## Overview
|
||||
|
||||
**Tile Distribution** describes how each thread in a thread block maps to elements of a block tile. It defines the hierarchical pattern of data distribution across threads, warps, and thread blocks.
|
||||
|
||||
## The Problem
|
||||
|
||||
Given a block tile of size `MPerBlock × KPerBlock` (e.g., 256×32), we need to determine:
|
||||
- Which threads load which elements.
|
||||
- How the threads are organized into warps.
|
||||
- The number of times each warp repeats its pattern.
|
||||
- The number of elements each thread can load in a single vector instruction.
|
||||
|
||||
---
|
||||
|
||||
## Bottom-Up Construction Approach
|
||||
|
||||
### Step 1: Determine K Dimension Layout
|
||||
|
||||
**Start with the innermost dimension (K) for memory coalescing:**
|
||||
|
||||
```cpp
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType); // Elements per thread (vector load)
|
||||
constexpr index_t K0 = kKPerBlock / K1; // Threads needed in K dimension
|
||||
```
|
||||
|
||||
**Example (with fp16):**
|
||||
- `K1 = 16 / 2 = 8` → Each thread loads 8 fp16 elements in a single vector instruction
|
||||
- `kKPerBlock = 32`
|
||||
- `K0 = 32 / 8 = 4` → We need 4 threads along K to cover the entire K dimension
|
||||
|
||||
**Visual:**
|
||||
```
|
||||
K dimension (32 elements):
|
||||
Thread 0: [0-7] Thread 1: [8-15] Thread 2: [16-23] Thread 3: [24-31]
|
||||
K1=8 K1=8 K1=8 K1=8
|
||||
├──────────────────────────────────────────────────────────────┤
|
||||
K0=4 threads
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 2: Determine M Dimension Layout
|
||||
|
||||
**Now partition the M dimension hierarchically:**
|
||||
|
||||
#### Level 1: Threads per Warp in M (M2)
|
||||
|
||||
```cpp
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
```
|
||||
|
||||
- Warp size = 64 threads
|
||||
- K dimension already uses `K0 = 4` threads per row
|
||||
- `M2 = 64 / 4 = 16` → Each warp can have 16 threads in M dimension
|
||||
|
||||
**Visual (Single Warp):**
|
||||
```
|
||||
K dimension (4 threads)
|
||||
┌─────┬─────┬─────┬─────┐
|
||||
0 │ T0 │ T1 │ T2 │ T3 │
|
||||
1 │ T4 │ T5 │ T6 │ T7 │
|
||||
2 │ T8 │ T9 │ T10 │ T11 │
|
||||
M 3 │ T12 │ T13 │ T14 │ T15 │ ← 16 rows
|
||||
...│ ... │ ... │ ... │ ... │ (M2=16)
|
||||
15 │ T60 │ T61 │ T62 │ T63 │
|
||||
└─────┴─────┴─────┴─────┘
|
||||
One Warp = 64 threads
|
||||
```
|
||||
|
||||
#### Level 2: Warps per Block (M1)
|
||||
|
||||
```cpp
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
```
|
||||
|
||||
- `kBlockSize = 256` threads per block
|
||||
- `M1 = 256 / 64 = 4` → We have 4 warps per block
|
||||
|
||||
**Visual (4 Warps):**
|
||||
```
|
||||
Warp 0 (rows 0-15)
|
||||
Warp 1 (rows 16-31)
|
||||
Warp 2 (rows 32-47)
|
||||
Warp 3 (rows 48-63)
|
||||
↑
|
||||
M1 = 4 warps cover 64 rows total
|
||||
```
|
||||
|
||||
#### Level 3: Repetitions (M0)
|
||||
|
||||
```cpp
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
```
|
||||
|
||||
- `kMPerBlock = 256` rows to cover
|
||||
- `M2 * M1 = 16 * 4 = 64` rows covered by all warps
|
||||
- `M0 = 256 / 64 = 4` → Each warp must repeat its pattern 4 times
|
||||
|
||||
**Visual (Complete Block):**
|
||||
```
|
||||
┌──────────────┐
|
||||
│ Iteration 0 │ ← Warp 0: rows 0-15, Warp 1: rows 16-31, ...
|
||||
│ (rows 0-63) │
|
||||
├──────────────┤
|
||||
│ Iteration 1 │ ← Warp 0: rows 64-79, Warp 1: rows 80-95, ...
|
||||
│ (rows 64-127)│
|
||||
├──────────────┤
|
||||
│ Iteration 2 │ ← Warp 0: rows 128-143, Warp 1: rows 144-159, ...
|
||||
│(rows 128-191)│
|
||||
├──────────────┤
|
||||
│ Iteration 3 │ ← Warp 0: rows 192-207, Warp 1: rows 208-223, ...
|
||||
│(rows 192-255)│
|
||||
└──────────────┘
|
||||
M0 = 4 iterations
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## The Tile Distribution Encoding
|
||||
|
||||
Now we can construct the distribution:
|
||||
|
||||
```cpp
|
||||
tile_distribution_encoding<
|
||||
sequence<1>, // [1] Replication
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // [2] Hierarchy
|
||||
tuple<sequence<1>, sequence<1, 2>>, // [3] Parallelism:
|
||||
tuple<sequence<1>, sequence<2, 0>>, // [3] Parallelism
|
||||
sequence<1, 2>, // [4] Yield
|
||||
sequence<0, 1> // [4] Yield
|
||||
>
|
||||
```
|
||||
|
||||
### [1] Replication: `sequence<1>`
|
||||
|
||||
Defines how many times warp patterns are replicated:
|
||||
- `1` = Each warp has a unique pattern (no replication)
|
||||
- `2` = Warp 0 and Warp 1 do the same thing, Warp 2 and Warp 3 do the same thing
|
||||
- `4` = All warps do the same thing
|
||||
|
||||
In our case: `1` means no replication (each warp is independent).
|
||||
|
||||
---
|
||||
|
||||
### [2] Hierarchy: The Multi-Level Structure
|
||||
|
||||
```cpp
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>
|
||||
└───────┬──────────┘ └──────┬────────┘
|
||||
M dimension K dimension
|
||||
```
|
||||
|
||||
**Concrete values:**
|
||||
- M hierarchy: `sequence<4, 4, 16>` = (4 repetitions, 4 warps, 16 threads/warp)
|
||||
- K hierarchy: `sequence<4, 8>` = (4 threads, 8 elements/thread)
|
||||
|
||||
---
|
||||
|
||||
### [3] Parallelism: Addressing the Hierarchy
|
||||
|
||||
**The key insight:** Read the tuples **vertically** to understand indexing!
|
||||
|
||||
```cpp
|
||||
tuple<sequence<1>, sequence<1, 2>>
|
||||
tuple<sequence<1>, sequence<2, 0>>
|
||||
```
|
||||
|
||||
#### Reading Pattern
|
||||
|
||||
**Column 1 (Dimension 0 = M):**
|
||||
```
|
||||
sequence<1> → Address hierarchy index 1,1 → M1 (warps/block in M dimension)
|
||||
sequence<1>
|
||||
```
|
||||
|
||||
**Column 2 (Dimension 1 = K):**
|
||||
```
|
||||
sequence<1, 2>
|
||||
sequence<2, 0>
|
||||
```
|
||||
[1,2] M2=threads/warp in M dimension
|
||||
[2,0] K0=threads/warp in K dimension
|
||||
|
||||
---
|
||||
|
||||
### [4] Yield Sequences: Output Ordering
|
||||
|
||||
```cpp
|
||||
sequence<1, 2>
|
||||
sequence<0, 1>
|
||||
|
||||
[1,0] means M0=repetitions/warp in M dimension
|
||||
[2,1] means K1=elements/thread in K dimension
|
||||
```
|
||||
---
|
||||
|
||||
## Complete Example: Thread 25 in Warp 0
|
||||
|
||||
Let's trace where **Thread 25** in **Warp 0** reads data:
|
||||
|
||||
### Thread Coordinates
|
||||
- Thread ID in warp: 25
|
||||
- Warp ID in block: 0
|
||||
|
||||
### Decompose Thread 25
|
||||
```
|
||||
Thread 25 in a 2D layout (M2=16, K0=4):
|
||||
Row index: 25 / 4 = 6
|
||||
Col index: 25 % 4 = 1
|
||||
```
|
||||
|
||||
### M Position (Row)
|
||||
```
|
||||
M0 iteration: 0 (first iteration)
|
||||
M1 warp: 0 (warp 0)
|
||||
M2 thread: 6 (6th row in warp)
|
||||
→ M position = 0*64 + 0*16 + 6 = 6
|
||||
```
|
||||
|
||||
### K Position (Column)
|
||||
```
|
||||
K0 thread: 1 (column group 1)
|
||||
K1 elements: 8 (will load 8 consecutive elements)
|
||||
→ K position = 1*8 + [0-7] = elements 8-15
|
||||
```
|
||||
|
||||
**Result:** Thread 25 in Warp 0 loads **row 6, columns 8-15** (8 elements).
|
||||
|
||||
---
|
||||
|
||||
## Why This Matters
|
||||
|
||||
### 1. **Memory Coalescing**
|
||||
- Consecutive threads access consecutive memory → efficient global memory access
|
||||
- K dimension uses K1=8 for vectorized loads
|
||||
|
||||
### 2. **Warp Efficiency**
|
||||
- All 64 threads in a warp are utilized
|
||||
- Natural 2D layout: 16 threads (M) × 4 threads (K) = 64 threads
|
||||
|
||||
### 3. **Scalability**
|
||||
- M0 repetitions allow handling larger tiles
|
||||
- Same pattern scales to different sizes
|
||||
|
||||
### 4. **Register Allocation**
|
||||
- Each thread knows exactly how many elements it will hold
|
||||
- Compiler can allocate registers optimally
|
||||
|
||||
---
|
||||
|
||||
## Summary Table
|
||||
|
||||
| Parameter | Value | Meaning |
|
||||
|-----------|-------|---------|
|
||||
| **K1** | 8 | Elements per thread (vector width) |
|
||||
| **K0** | 4 | Threads along K per row |
|
||||
| **M2** | 16 | Threads along M per warp |
|
||||
| **M1** | 4 | Warps per block |
|
||||
| **M0** | 4 | Repetitions of warp pattern |
|
||||
| **Total Threads** | 256 | M0×M1×M2 = 4×4×16 (actually M1×64) |
|
||||
| **Total Elements** | 8192 | 256×32 (MPerBlock × KPerBlock) |
|
||||
| **Elements/Thread** | 32 | M0×K1 = 4×8 |
|
||||
|
||||
---
|
||||
|
||||
## Visualization: Complete Thread Block
|
||||
|
||||
```
|
||||
Block Tile: 256×32
|
||||
|
||||
K dimension (32 elements)
|
||||
├─────────────────────┤
|
||||
0 ┌──────────────────────┐ ┐
|
||||
16 │ Warp 0 │ │
|
||||
32 │ Warp 1 │ │ Iteration 0
|
||||
48 │ Warp 2 │ │ (M0=0)
|
||||
64 │ Warp 3 │ ┘
|
||||
80 ├──────────────────────┤ ┐
|
||||
96 │ Warp 0 │ │
|
||||
112 │ Warp 1 │ │ Iteration 1
|
||||
128 │ Warp 2 │ │ (M0=1)
|
||||
144 │ Warp 3 │ ┘
|
||||
160 ├──────────────────────┤ ┐
|
||||
176 │ Warp 0 │ │
|
||||
192 │ Warp 1 │ │ Iteration 2
|
||||
208 │ Warp 2 │ │ (M0=2)
|
||||
224 │ Warp 3 │ ┘
|
||||
240 ├──────────────────────┤ ┐
|
||||
256 │ Warp 0 │ │
|
||||
│ Warp 1 │ │ Iteration 3
|
||||
│ Warp 2 │ │ (M0=3)
|
||||
│ Warp 3 │ ┘
|
||||
└──────────────────────┘
|
||||
|
||||
Each warp processes 16 rows × 32 cols = 512 elements
|
||||
Each iteration processes 64 rows × 32 cols = 2048 elements
|
||||
Total: 4 iterations × 2048 = 8192 elements ✓
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Takeaways
|
||||
|
||||
1. **Bottom-up construction**: Start from vector width (K1), build up through thread/warp/block hierarchy
|
||||
2. **Vertical reading**: The repeat and elements tuples are read column-wise to address hierarchy levels
|
||||
3. **Replication controls redundancy**: How many warps share the same pattern
|
||||
4. **Hierarchy encodes structure**: The multi-level sequence defines the complete mapping
|
||||
|
||||
This design enables CK to achieve maximum GPU performance through optimal thread-to-data mapping!
|
||||
|
||||
@@ -1,506 +0,0 @@
|
||||
# Practice GEMM: Step-by-Step Code Walkthrough
|
||||
|
||||
This document provides a detailed walkthrough of `practice_gemm.cpp`, explaining each step of implementing a GEMM (General Matrix Multiplication) kernel using the CK Tile API.
|
||||
|
||||
## Overview
|
||||
|
||||
We'll implement `C = A × B` where:
|
||||
- `A` is an `M × K` matrix
|
||||
- `B` is an `N × K` matrix (note: transposed layout)
|
||||
- `C` is an `M × N` matrix
|
||||
|
||||
The implementation uses a hierarchical tiling strategy with two levels:
|
||||
1. **Block Tiles**: Processed by thread blocks
|
||||
2. **Wave Tiles**: Processed by warps (wavefronts) within blocks
|
||||
|
||||
---
|
||||
|
||||
## Step 1: Define Data Types
|
||||
|
||||
```cpp
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using CDataType = float;
|
||||
using AccDataType = float;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We use `half_t` (FP16) for input matrices A and B.
|
||||
- We use `float` (FP32) for output matrix C and accumulation for numerical accuracy
|
||||
- In typical CK examples, this information is part of a `GemmConfig` struct, but here we define it directly for simplicity
|
||||
---
|
||||
|
||||
## Step 2: Define Problem Size
|
||||
|
||||
```cpp
|
||||
ck_tile::index_t M = 512;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 64;
|
||||
ck_tile::index_t verification = 1;
|
||||
|
||||
ck_tile::index_t stride_a = K;
|
||||
ck_tile::index_t stride_b = K;
|
||||
ck_tile::index_t stride_c = N;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- `M = 512`: Number of rows in A and C
|
||||
- `N = 256`: Number of columns in B and C
|
||||
- `K = 64`: Inner dimension (columns of A, rows of B)
|
||||
- Strides define memory layout (row-major for A and C, transposed for B)
|
||||
|
||||
**Memory Layout:**
|
||||
```
|
||||
Matrix A (M×K): Matrix B (N×K): Matrix C (M×N):
|
||||
[512 rows] [256 rows] [512 rows]
|
||||
[64 cols] [64 cols] [256 cols]
|
||||
stride = K stride = K stride = N
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 3: Create Host Tensors
|
||||
|
||||
```cpp
|
||||
auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
|
||||
auto a_strides = std::array<ck_tile::index_t, 2>{stride_a, 1};
|
||||
auto b_strides = std::array<ck_tile::index_t, 2>{stride_b, 1};
|
||||
auto c_strides = std::array<ck_tile::index_t, 2>{stride_c, 1};
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides);
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We create three tensors on the host (CPU) memory
|
||||
- Each tensor is defined by its shape (`lengths`) and memory layout (`strides`)
|
||||
- `HostTensor` is a CK Tile utility class that manages CPU memory
|
||||
|
||||
**Stride explanation:**
|
||||
- For A: `stride_a = K` means moving to the next row requires skipping K elements
|
||||
- For B: `stride_b = K` means B is stored in transposed format
|
||||
- For C: `stride_c = N` means row-major layout
|
||||
|
||||
---
|
||||
|
||||
## Step 4: Initialize Tensors with Random Data
|
||||
|
||||
```cpp
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
|
||||
c_host.SetZero();
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- A and B are filled with random values in the range [-5.0, 5.0]
|
||||
- C is initialized to zero (will store the output)
|
||||
|
||||
**Optional: Print Tensor Contents**
|
||||
```cpp
|
||||
// Commented out in the code, but available for debugging:
|
||||
// a_host.print_first_n(10); // Print first 10 elements of A
|
||||
```
|
||||
|
||||
The `print_first_n()` helper function can display tensor contents for debugging purposes.
|
||||
|
||||
---
|
||||
|
||||
## Step 5: Allocate Device Memory and Transfer Data
|
||||
|
||||
```cpp
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
ck_tile::DeviceMem b_device(b_host);
|
||||
ck_tile::DeviceMem c_device(c_host);
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- `DeviceMem` allocates GPU memory matching the size of host tensors
|
||||
- The constructor **automatically transfers data from host to device**
|
||||
- This is a convenience wrapper around `hipMalloc` and `hipMemcpy`
|
||||
|
||||
**Memory Flow:**
|
||||
```
|
||||
CPU (Host) GPU (Device)
|
||||
┌─────────┐ ┌─────────┐
|
||||
│ a_host │ ────────> │a_device │
|
||||
│ b_host │ ────────> │b_device │
|
||||
│ c_host │ ────────> │c_device │
|
||||
└─────────┘ └─────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 6: Configure Hierarchical Tiling
|
||||
|
||||
```cpp
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>;
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We define a two-level tiling hierarchy for the GEMM computation
|
||||
|
||||
### Block Tile (256 × 128 × 32)
|
||||
- **256**: M dimension per block (rows of A and C)
|
||||
- **128**: N dimension per block (columns of B and C)
|
||||
- **32**: K dimension per block (inner dimension)
|
||||
- Each block tile is processed by one **thread block** (256 threads)
|
||||
|
||||
### Wave Tile (16 × 16 × 16)
|
||||
- **16 × 16**: Output tile dimensions (M × N) per warp iteration
|
||||
- **16**: K dimension per warp iteration
|
||||
- Each wave tile is processed by one **warp** (64 threads on AMD GPUs)
|
||||
|
||||
**Important:** The WaveTile (16×16×16) is NOT the same as the MFMA instruction size (32×32×8). The WaveTile represents the work done per warp per iteration, while MFMA is the underlying hardware instruction. Multiple MFMA operations may be needed to compute one wave tile
|
||||
|
||||
**Important Note:**
|
||||
In this example, the problem size (256 × 128 × 32) is **identical** to the block tile size, so only **one thread block** is needed to compute the entire problem.
|
||||
|
||||
### Tiling Visualization:
|
||||
|
||||
#### Matrix A (M × K = 256 × 32):
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ One Block Tile (256 × 32) │
|
||||
│ ┌────┬────┐ │
|
||||
│ │16×│16× │ ← Wave tiles (16×16) │
|
||||
│ │ 16│ 16 │ in M×K space │
|
||||
│ ├────┼────┤ │
|
||||
│ │ │ │ │
|
||||
│ ├────┼────┤ │
|
||||
│ │ .. │ .. │ 16 tiles in M │
|
||||
│ ├────┼────┤ 2 tiles in K │
|
||||
│ │ │ │ │
|
||||
│ └────┴────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────┘
|
||||
```
|
||||
|
||||
#### Matrix B (N × K = 128 × 32):
|
||||
```
|
||||
┌──────────────────────────────┐
|
||||
│ One Block Tile (128 × 32) │
|
||||
│ ┌────┬────┐ │
|
||||
│ │16×│16× │ ← Wave tiles │
|
||||
│ │ 16│ 16 │ (16×16) │
|
||||
│ ├────┼────┤ │
|
||||
│ │ │ │ │
|
||||
│ ├────┼────┤ 8 tiles in N │
|
||||
│ │ .. │ .. │ 2 tiles in K │
|
||||
│ ├────┼────┤ │
|
||||
│ │ │ │ │
|
||||
│ └────┴────┘ │
|
||||
└──────────────────────────────┘
|
||||
```
|
||||
|
||||
#### Matrix C (M × N = 256 × 128) - Output:
|
||||
```
|
||||
┌─────────────────────────────────────────────────┐
|
||||
│ One Block Tile (256 × 128) │
|
||||
│ │
|
||||
│ ┌────┬────┬────┬────┬────┬────┬────┬────┐ │
|
||||
│ │16× │ │ │ │ │ │ │ │ │
|
||||
│ │ 16 │ │ │ │ │ │ │ │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ │ │ │ │ │ │ │ │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ │ │ │ │ │ │ │ │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ │ │ │ │ │ │ │ │ │
|
||||
│ └────┴────┴────┴────┴────┴────┴────┴────┘ │
|
||||
│ │
|
||||
│ 16 wave tiles in M direction │
|
||||
│ 8 wave tiles in N direction │
|
||||
│ Total: 128 wave tiles (16×16 each) │
|
||||
└─────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
#### How Wave Tiles Combine (C = A × B):
|
||||
```
|
||||
Matrix A Matrix B (stored transposed N×K) Matrix C
|
||||
(256×32) (128×32) (256×128)
|
||||
|
||||
Row of A tiles: Row of B tiles: One wave tile in C:
|
||||
┌────┬────┐ ┌────┬────┐ ┌────┐
|
||||
│ A₀ │ A₁ │ × │ B₀ │ B₁ │ = │ C │ (16×16)
|
||||
└────┴────┘ └────┴────┘ └────┘
|
||||
16×16 each 16×16 each
|
||||
|
||||
Computation: C = A₀×B₀ᵀ + A₁×B₁ᵀ
|
||||
↑ ↑
|
||||
K=0..15 K=16..31
|
||||
|
||||
Each wave tile in C is computed by:
|
||||
- Taking one row of wave tiles from A (2 tiles along K)
|
||||
- Taking one row of wave tiles from B (2 tiles along K)
|
||||
Note: B is stored transposed (N×K), so a "row" in storage corresponds
|
||||
to a "column" in the logical B^T matrix used in computation
|
||||
- Performing dot product: Σ(A_k × B_k^T) for k=0,1
|
||||
```
|
||||
|
||||
**Key Insight:**
|
||||
- Each **wave tile in C** (16×16) requires a **dot product** of 2 wave tiles from A and 2 wave tiles from B
|
||||
- Since B is stored transposed (N×K layout), we access **rows** of B tiles in memory
|
||||
- This is the fundamental operation repeated across all 128 wave tiles in C
|
||||
- Each warp computes one wave tile using MFMA instructions
|
||||
|
||||
---
|
||||
|
||||
## Step 7: Create Shape, Problem, and Policy Structs
|
||||
|
||||
```cpp
|
||||
using PracticeGemmShape = ck_tile::PracticeGemmShape<BlockTile, WaveTile>;
|
||||
std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl;
|
||||
|
||||
using PracticeGemmHostProblem = ck_tile::
|
||||
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
|
||||
|
||||
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### 1. **Shape Struct**
|
||||
Encapsulates all tile shape information (BlockTile and WaveTile dimensions).
|
||||
|
||||
### 2. **Problem Struct**
|
||||
Holds complete problem description:
|
||||
- Data types (ADataType, BDataType, CDataType, AccDataType)
|
||||
- Shape information (BlockTile, WaveTile)
|
||||
|
||||
In more complex examples, this would also include:
|
||||
- Data layouts (row-major, column-major)
|
||||
- Mathematical operations (e.g., transposed GEMM)
|
||||
|
||||
### 3. **Policy Struct**
|
||||
Describes data movement and thread-to-data mapping:
|
||||
- Currently contains `MakeBlock2TileMap()`: Maps thread block IDs to tile positions
|
||||
- In more complex kernels, includes:
|
||||
- DRAM access patterns
|
||||
- LDS (Local Data Share) usage strategies
|
||||
- Thread distribution within blocks
|
||||
|
||||
**CK Tile Design Pattern:**
|
||||
```
|
||||
Kernel = Problem + Policy + Epilogue
|
||||
↑ ↑ ↑
|
||||
(What) (How) (Post-processing)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 8: Calculate Grid and Block Dimensions
|
||||
|
||||
```cpp
|
||||
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
|
||||
ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
|
||||
|
||||
std::cout << "kGridSize: " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
constexpr ck_tile::index_t kBlockPerCU = 1;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### Grid Size Calculation
|
||||
```cpp
|
||||
kGridSize = ceil(M / BlockTile_M) × ceil(N / BlockTile_N)
|
||||
= ceil(512 / 256) × ceil(256 / 128)
|
||||
= 2 × 2
|
||||
= 4 thread blocks
|
||||
```
|
||||
|
||||
Our problem requires **4 thread blocks** to cover the entire output matrix C (2 blocks in M direction, 2 blocks in N direction).
|
||||
|
||||
### Block Configuration
|
||||
- `kBlockSize = 256`: Each thread block has 256 threads
|
||||
- 256 threads / 64 threads per warp = **4 warps per block**
|
||||
- `kBlockPerCU = 1`: Launch 1 block per Compute Unit (for simplicity)
|
||||
|
||||
**Thread Hierarchy:**
|
||||
```
|
||||
GPU
|
||||
└── 1 Thread Block (Grid)
|
||||
└── 256 Threads
|
||||
├── Warp 0 (threads 0-63)
|
||||
├── Warp 1 (threads 64-127)
|
||||
├── Warp 2 (threads 128-191)
|
||||
└── Warp 3 (threads 192-255)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 9: Create and Launch the Kernel
|
||||
|
||||
```cpp
|
||||
using gemm_kernel =
|
||||
ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 0, 1},
|
||||
ck_tile::make_kernel<kBlockPerCU>(gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_device.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c));
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### 1. Kernel Composition
|
||||
```cpp
|
||||
using gemm_kernel = ck_tile::PracticeGemmKernel<Problem, Policy>;
|
||||
```
|
||||
The kernel is composed from Problem and Policy structs, following the CK Tile design pattern.
|
||||
|
||||
### 2. Kernel Launch
|
||||
`launch_kernel()` is a CK Tile utility that:
|
||||
- Launches the GPU kernel using HIP runtime
|
||||
- Measures execution time
|
||||
- Returns average execution time in milliseconds
|
||||
|
||||
### 3. Launch Parameters
|
||||
- **Stream config**: `{nullptr, true, 0, 0, 1}` - default stream, timing enabled
|
||||
- **Grid size**: `kGridSize = 1` - number of thread blocks
|
||||
- **Block size**: `kBlockSize = 256` - threads per block
|
||||
- **Shared memory**: `0` - no dynamic shared memory in this example
|
||||
- **Kernel arguments**: Device pointers and problem dimensions
|
||||
|
||||
### 4. Kernel Execution Flow
|
||||
```
|
||||
launch_kernel() calls gemm_kernel.operator()()
|
||||
↓
|
||||
PracticeGemmKernel::operator()
|
||||
↓
|
||||
Creates tensor views over device memory
|
||||
↓
|
||||
Calls block-level pipeline
|
||||
↓
|
||||
Block pipeline calls warp-level pipeline
|
||||
↓
|
||||
Warp pipeline calls MFMA instructions
|
||||
↓
|
||||
Results written back to C matrix
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 10: Verify Results
|
||||
|
||||
```cpp
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// Reference gemm on CPU
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
|
||||
// Copy GPU results back to host
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
c_device.FromDevice(c_host_dev.mData.data());
|
||||
|
||||
// Compare results
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### 1. CPU Reference Implementation
|
||||
```cpp
|
||||
reference_basic_gemm<...>(a_host, b_host, c_host_ref);
|
||||
```
|
||||
Computes GEMM on CPU using a simple nested loop implementation (ground truth).
|
||||
|
||||
### 2. Copy GPU Results to Host
|
||||
```cpp
|
||||
c_device.FromDevice(c_host_dev.mData.data());
|
||||
```
|
||||
Transfers the computed result from GPU memory back to CPU for comparison.
|
||||
|
||||
### 3. Error Checking
|
||||
```cpp
|
||||
ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
|
||||
```
|
||||
Compares GPU and CPU results element-wise with tolerance:
|
||||
- **Relative error**: 1e-3 (0.1%)
|
||||
- **Absolute error**: 1e-3
|
||||
|
||||
**Verification Flow:**
|
||||
```
|
||||
CPU GPU
|
||||
┌─────────┐ ┌─────────┐
|
||||
│ a_host │ ────────> │a_device │
|
||||
│ b_host │ ────────> │b_device │
|
||||
└─────────┘ └─────────┘
|
||||
│ │
|
||||
↓ ↓
|
||||
reference_gemm() GPU kernel
|
||||
│ │
|
||||
↓ ↓
|
||||
┌──────────┐ ┌──────────┐
|
||||
│c_host_ref│ │c_device │
|
||||
└──────────┘ └──────────┘
|
||||
│ │
|
||||
│ ↓
|
||||
│ FromDevice()
|
||||
│ │
|
||||
↓ ↓
|
||||
└────> check_err() <───┘
|
||||
│
|
||||
↓
|
||||
Pass/Fail
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Complete Execution Flow Summary
|
||||
|
||||
```
|
||||
1. Define data types (FP16 inputs, FP32 output)
|
||||
↓
|
||||
2. Set problem size (M=256, N=128, K=32)
|
||||
↓
|
||||
3. Create host tensors and initialize with random data
|
||||
↓
|
||||
4. Allocate device memory and transfer data (CPU → GPU)
|
||||
↓
|
||||
5. Configure hierarchical tiling (BlockTile, WaveTile)
|
||||
↓
|
||||
6. Create Shape, Problem, and Policy structs
|
||||
↓
|
||||
7. Calculate grid/block dimensions (1 block, 256 threads)
|
||||
↓
|
||||
8. Compose and launch kernel (Problem + Policy)
|
||||
↓
|
||||
9. Execute GEMM on GPU
|
||||
│ ├─ Block-level pipeline
|
||||
│ ├─ Warp-level pipeline
|
||||
│ └─ MFMA instructions
|
||||
↓
|
||||
10. Verify results (compare GPU vs CPU reference)
|
||||
↓
|
||||
11. Calculate and print performance metrics
|
||||
↓
|
||||
12. Return success/failure
|
||||
```
|
||||
|
||||
---
|
||||
@@ -6,6 +6,14 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
// Controls whether to use the A/B-swapped MFMA variant with transposed C register layout.
|
||||
// 0 = WarpGemmMfmaF16F16F32M32N32K8 (standard, no swap, no transposed C)
|
||||
// 1 = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution (swap A/B in MFMA + transposed C
|
||||
// layout)
|
||||
#ifndef CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION
|
||||
#define CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION 1
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmASmemBSmemCReg
|
||||
@@ -15,24 +23,31 @@ struct BlockGemmASmemBSmemCRegPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
// NAIVE_IMPLEMENTATION uses 4x1 warp configuration
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
|
||||
// NAIVE_IMPLEMENTATION uses mfma m32 n32 k8
|
||||
// mfma m32 n32 k8
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, kMWarp, kNWarp);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8{}, kMWarp, kNWarp);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../warp_level/block_gemm_asmem_bsmem_creg.hpp"
|
||||
#include "block_gemm_asmem_bsmem_creg.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
@@ -8,8 +8,8 @@
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#include "block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "host_level/grid_gemm.hpp"
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "grid_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
21
tutorial/ck_tile/tile_distribution/CMakeLists.txt
Normal file
21
tutorial/ck_tile/tile_distribution/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# These tutorials are hard-coded for CDNA (warp_size=64) with specific tile sizes.
|
||||
# Only build for gfx942 (MI300X) and gfx950 (MI350X).
|
||||
if(NOT (GPU_TARGETS MATCHES "gfx942|gfx950"))
|
||||
message(VERBOSE "Skipping tile_distribution tutorials: requires gfx942 or gfx950")
|
||||
return()
|
||||
endif()
|
||||
|
||||
foreach(i 1 2 3)
|
||||
set(TUTORIAL_NAME "tile_tutorial_tile_distribution_${i}")
|
||||
|
||||
add_executable(${TUTORIAL_NAME} EXCLUDE_FROM_ALL tile_distribution_${i}.cpp)
|
||||
target_include_directories(${TUTORIAL_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_compile_options(${TUTORIAL_NAME} PRIVATE
|
||||
-Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported
|
||||
)
|
||||
|
||||
add_dependencies(tutorials ${TUTORIAL_NAME})
|
||||
endforeach()
|
||||
63
tutorial/ck_tile/tile_distribution/README.md
Normal file
63
tutorial/ck_tile/tile_distribution/README.md
Normal file
@@ -0,0 +1,63 @@
|
||||
# CK Tile Distribution Encoding Tutorial
|
||||
|
||||
## Overview
|
||||
|
||||
Every `load_tile` and `store_tile` in CK needs to know **which thread reads which data element**.
|
||||
This mapping is defined by a `tile_distribution_encoding` — a compile-time struct with 6 template
|
||||
parameters:
|
||||
|
||||
```cpp
|
||||
tile_distribution_encoding<Rs, Hs, Ps_major, Ps_minor, Ys_major, Ys_minor>
|
||||
```
|
||||
|
||||
Every level of **Hs** (hierarchical dimensions) is assigned to exactly one role:
|
||||
|
||||
| Role | Meaning |
|
||||
|------|---------|
|
||||
| **P** (parallel) | Thread ID selects which slice — different threads get different data |
|
||||
| **Y** (yield) | Each thread owns the entire range in its buffer |
|
||||
| **R** (replicate) | Identical data broadcast to multiple thread groups |
|
||||
|
||||
## Tutorials
|
||||
|
||||
These tutorials use the exact tile sizes from the naive GEMM tutorial
|
||||
(`01_naive_gemm/`): MPerBlock=256, NPerBlock=128, KPerBlock=32, BlockSize=256, fp16.
|
||||
|
||||
| # | File | Matrix | Tile | Key Concept |
|
||||
|---|------|--------|------|-------------|
|
||||
| 1 | `tile_distribution_1.cpp` | A (DRAM load) | 256×32 | NDimP=2, warp\_id→M1, lane\_id→M2×K0 (coalesced) |
|
||||
| 2 | `tile_distribution_2.cpp` | B (DRAM load) | 128×32 | Same pattern as A, but N0=2 iterations (vs A's M0=4) due to smaller N |
|
||||
| 3 | `tile_distribution_3.cpp` | C (registers) | 256×128 | Warp-level MFMA output + block-level composition, standard vs transposed |
|
||||
|
||||
Tutorial 3 responds to `CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION` — rebuild with `=0` or `=1`
|
||||
to see both C register layouts.
|
||||
|
||||
**Architecture note:** All comments and concrete values assume **CDNA (warp_size=64)**.
|
||||
On RDNA (warp_size=32), the thread-to-data mapping will differ.
|
||||
|
||||
## Building
|
||||
|
||||
```bash
|
||||
cd <repo-root>/projects/composablekernel/build
|
||||
|
||||
# Build all tutorials:
|
||||
make tutorials -j
|
||||
# or: ninja tutorials
|
||||
|
||||
# Or build individually:
|
||||
make tile_tutorial_tile_distribution_1 -j
|
||||
make tile_tutorial_tile_distribution_2 -j
|
||||
make tile_tutorial_tile_distribution_3 -j
|
||||
|
||||
# Tutorial 3 with standard (non-transposed) C:
|
||||
cmake -DCMAKE_CXX_FLAGS="-DCK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0" ..
|
||||
make tile_tutorial_tile_distribution_3 -j
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
- Encoding definition: `include/ck_tile/core/tensor/tile_distribution_encoding.hpp`
|
||||
- Thread identity (NDimP): `include/ck_tile/core/tensor/tile_distribution.hpp`
|
||||
- MFMA warp output layout: `include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp`
|
||||
- Production A/B distributions: `include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp`
|
||||
- Naive GEMM tutorial: `tutorial/ck_tile/gemm/01_naive_gemm/`
|
||||
285
tutorial/ck_tile/tile_distribution/tile_distribution_1.cpp
Normal file
285
tutorial/ck_tile/tile_distribution/tile_distribution_1.cpp
Normal file
@@ -0,0 +1,285 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/*
|
||||
* Tutorial: CK Tile Distribution Encoding — A Matrix DRAM Load
|
||||
*
|
||||
* Demonstrates how tile_distribution_encoding maps threads to A-matrix
|
||||
* elements during a DRAM load in the naive GEMM tutorial.
|
||||
*
|
||||
* Source: block_gemm_pipeline_agmem_bgmem_creg_policy.hpp
|
||||
* MakeADramTileDistribution(), with fp16, BlockSize=256
|
||||
*
|
||||
* Tile: M=256 × K=32 (matches the naive GEMM's A block tile)
|
||||
* Threads: 256 (4 warps on CDNA, 8 on RDNA)
|
||||
*
|
||||
* Host initialises A with sequential values 0, 1, 2, ... (row-major).
|
||||
* A[m][k] = m * K + k, so the printed value directly gives the linear index.
|
||||
* GPU kernel loads A using the distribution, then prints per-thread buffer
|
||||
* contents so the reader can verify which elements each thread received.
|
||||
*
|
||||
* Note: int32_t is used instead of fp16 for readable printf output.
|
||||
* The distribution encoding is hardcoded to match the fp16 derivation
|
||||
* (K1=16/sizeof(fp16)=8), not recomputed from sizeof(int32_t).
|
||||
*
|
||||
* No compute is performed — this is purely about data movement.
|
||||
*
|
||||
* Note: Comments and values assume CDNA (warp_size=64). On RDNA (warp_size=32),
|
||||
* the thread-to-data mapping will differ.
|
||||
*/
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include <cstdio>
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
// ============================================================================
|
||||
// THE GOAL
|
||||
// ============================================================================
|
||||
// Matrix A: M=256 rows × K=32 columns, stored in DRAM (row-major, fp16).
|
||||
// Load the entire tile into registers using 256 threads (4 warps on CDNA).
|
||||
//
|
||||
// For coalesced memory access with fp16, each lane loads 8 contiguous
|
||||
// K-values (8 × 2 bytes = 16 bytes = 128 bits). Since K=32, we need
|
||||
// 32/8 = 4 lanes to cover one row:
|
||||
//
|
||||
// lane 0: K=0..7 lane 1: K=8..15 lane 2: K=16..23 lane 3: K=24..31
|
||||
// └──────────────── one row of 32 K-columns ──────────────────────────────┘
|
||||
//
|
||||
// With warp_size=64, each warp has 64 lanes. 4 lanes per row means
|
||||
// 64/4 = 16 rows per warp. With 4 warps, one pass covers 4×16 = 64 rows.
|
||||
// To cover all 256 rows, each thread iterates M0 = 256/64 = 4 times.
|
||||
//
|
||||
// Per-thread buffer = 4 iterations × 8 K-values = 32 elements.
|
||||
//
|
||||
// Visually for warp 0 (lanes 0–63):
|
||||
//
|
||||
// A matrix (256×32) lane_id decomposition
|
||||
// ──────────────── ──────────────────────
|
||||
// row 0: [ K=0..7 | 8..15 | 16..23 | 24..31 ]
|
||||
// L0 L1 L2 L3 ← iter 0
|
||||
// row 1: [ K=0..7 | 8..15 | 16..23 | 24..31 ]
|
||||
// L4 L5 L6 L7
|
||||
// ...
|
||||
// row 15: same pattern, lanes 60–63
|
||||
// ────── stride of 64 rows (4 warps × 16 rows/warp) ──────
|
||||
// row 64: L0..L3 ← iter 1
|
||||
// ...
|
||||
// row 128: L0..L3 ← iter 2
|
||||
// ...
|
||||
// row 192: L0..L3 ← iter 3
|
||||
//
|
||||
// ============================================================================
|
||||
// THE SOLUTION: tile_distribution_encoding
|
||||
// ============================================================================
|
||||
//
|
||||
// Production code derives (fp16, BlockSize=256, MPerBlock=256, KPerBlock=32):
|
||||
// K1 = 16/sizeof(fp16) = 8 → vector load width (8 values)
|
||||
// K0 = KPerBlock/K1 = 4 → 4 K-chunks per row
|
||||
// M2 = warp_size/K0 = 16 → 16 rows per warp
|
||||
// M1 = BlockSize/warp_size = 4 → 4 warps
|
||||
// M0 = MPerBlock/(M2*M1) = 4 → 4 iterations
|
||||
//
|
||||
// Step 1 — Hierarchical dimensions (Hs): factor each axis.
|
||||
//
|
||||
// Hs[0] = sequence<4, 4, 16> → M = 4 × 4 × 16 = 256
|
||||
// Hs[1] = sequence<4, 8> → K = 4 × 8 = 32
|
||||
//
|
||||
// Hs[0] Hs[1]
|
||||
// ┌─────┼─────┐ ┌───┴───┐
|
||||
// level 0 level 1 level 2 level 0 level 1
|
||||
// = 4 = 4 = 16 = 4 = 8
|
||||
//
|
||||
// Step 2 — Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id).
|
||||
//
|
||||
// P0 = warp_id → Hs[0][1] = 4 (which warp → which M-group)
|
||||
// P1 = lane_id → Hs[0][2]=16 AND Hs[1][0]=4 (merged, total=64)
|
||||
//
|
||||
// The merge transform decomposes lane_id:
|
||||
// row_in_warp = lane_id / 4 (0..15, outer)
|
||||
// k_chunk = lane_id % 4 (0..3, inner → coalesced!)
|
||||
//
|
||||
// Ps_major = tuple<sequence<1>, sequence<1, 2>>
|
||||
// Ps_minor = tuple<sequence<1>, sequence<2, 0>>
|
||||
//
|
||||
// How to read Ps: the tuple has 2 elements → NDimP=2.
|
||||
// First element = P0 = warp_id
|
||||
// Second element = P1 = lane_id
|
||||
//
|
||||
// Ps_major = tuple< seq<1>, seq<1, 2> >
|
||||
// ─P0(warp)─ ─P1(lane)──
|
||||
// Ps_minor = tuple< seq<1>, seq<2, 0> >
|
||||
// ─P0(warp)─ ─P1(lane)──
|
||||
//
|
||||
// P0: major=<1>, minor=<1> → Hs[0], level 1 → M1=4
|
||||
// P1: major=<1,2>, minor=<2,0> → merged:
|
||||
// Hs[0] level 2 → M2=16 (outer, changes slowly)
|
||||
// Hs[1] level 0 → K0=4 (inner, changes every lane → coalesced!)
|
||||
// Total: 16 × 4 = 64 = warp_size
|
||||
// lane / 4 → row_in_warp (M2), lane % 4 → K-chunk (K0)
|
||||
//
|
||||
// Step 3 — Yield dimensions (Ys): what each thread owns.
|
||||
//
|
||||
// Y0 = Hs[0][0] = 4 (M-iterations)
|
||||
// Y1 = Hs[1][1] = 8 (vector load width)
|
||||
//
|
||||
// Ys_major = sequence<1, 2>
|
||||
// Ys_minor = sequence<0, 1>
|
||||
//
|
||||
// How to read Ys: parallel arrays — position i gives Yi.
|
||||
//
|
||||
// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 1 > → Y0 is level 0, Y1 is level 1
|
||||
// ─Y0─ ─Y1─
|
||||
//
|
||||
// Y0: Hs[0] level 0 → M0=4 (iterations along M)
|
||||
// Y1: Hs[1] level 1 → K1=8 (vector load width)
|
||||
// Buffer size = Y0 × Y1 = 4 × 8 = 32 elements per thread.
|
||||
//
|
||||
// Step 4 — Replicate: Rs = sequence<1> (trivial, size 1).
|
||||
//
|
||||
// Complete tree:
|
||||
//
|
||||
// Hs[0] Hs[1]
|
||||
// ┌─────┼─────┐ ┌───┴───┐
|
||||
// [Y0] [P0] [P1] [P1] [Y1]
|
||||
// = 4 = 4 = 16 = 4 = 8
|
||||
// (iter) (warp) (row) (K-chunk) (vec load)
|
||||
//
|
||||
// Buffer size = Y0 × Y1 = 4 × 8 = 32 elements per thread.
|
||||
//
|
||||
// ============================================================================
|
||||
|
||||
static constexpr index_t kM = 256;
|
||||
static constexpr index_t kK = 32;
|
||||
|
||||
struct TileDistKernelA
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const int32_t* p_data) const
|
||||
{
|
||||
static_assert(get_warp_size() == 64,
|
||||
"This tutorial is hard-coded for CDNA (warp_size=64). "
|
||||
"On RDNA (warp_size=32), the encoding values and print logic must change.");
|
||||
|
||||
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_data, make_tuple(kM, kK), make_tuple(kK, 1), number<1>{}, number<1>{});
|
||||
|
||||
constexpr auto distribution = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<4, 4, 16>, sequence<4, 8>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
|
||||
auto window = make_tile_window(
|
||||
a_tensor, make_tuple(number<kM>{}, number<kK>{}), {0, 0}, distribution);
|
||||
|
||||
const auto tile = load_tile(window);
|
||||
|
||||
const auto& buf = tile.get_thread_buffer();
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
constexpr index_t kBufSize = 32; // 4 iterations × 8 K-values
|
||||
|
||||
int32_t local_buf[kBufSize];
|
||||
static_for<0, kBufSize, 1>{}([&](auto i) { local_buf[i] = static_cast<int32_t>(buf[i]); });
|
||||
|
||||
auto print_thread = [&](int tid) {
|
||||
if(static_cast<int>(threadIdx.x) == tid)
|
||||
{
|
||||
int lane = tid % static_cast<int>(warp_size);
|
||||
int warp = tid / static_cast<int>(warp_size);
|
||||
int row_in_wrp = lane / 4;
|
||||
int k_chunk = lane % 4;
|
||||
|
||||
printf("Thread %3d (warp %d, lane %2d) row_in_warp=%2d k_chunk=%d\n",
|
||||
tid,
|
||||
warp,
|
||||
lane,
|
||||
row_in_wrp,
|
||||
k_chunk);
|
||||
|
||||
for(int iter = 0; iter < 4; iter++)
|
||||
{
|
||||
int row = iter * 64 + warp * 16 + row_in_wrp;
|
||||
int col = k_chunk * 8;
|
||||
printf(" iter %d: A[%3d][%2d..%2d] =", iter, row, col, col + 7);
|
||||
for(int k = 0; k < 8; k++)
|
||||
printf(" %5d", local_buf[iter * 8 + k]);
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if(blockIdx.x == 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("\n=== Tile Distribution: A-Matrix DRAM Load ===\n");
|
||||
printf("Source: MakeADramTileDistribution (fp16, BlockSize=256)\n");
|
||||
printf("Tile: %dx%d BlockSize: %d WarpSize: %d Warps: %d\n",
|
||||
kM,
|
||||
kK,
|
||||
kBlockSize,
|
||||
static_cast<int>(warp_size),
|
||||
kBlockSize / static_cast<int>(warp_size));
|
||||
printf("Each thread: 4 iterations x 8 K-values = 32 elements\n\n");
|
||||
printf("Coalescing: lanes 0-3 read K=0..31 of the same row\n");
|
||||
printf(" (4 x 8 = 32 K-values = one full row)\n\n");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Lane 0: row_in_warp=0, k_chunk=0 → rows {0, 64, 128, 192}, K=0..7
|
||||
print_thread(0);
|
||||
__syncthreads();
|
||||
// Lane 1: k_chunk=1 → same rows, K=8..15 (coalesced with lane 0)
|
||||
print_thread(1);
|
||||
__syncthreads();
|
||||
// Lane 4: row_in_warp=1 → rows {1, 65, 129, 193}, K=0..7
|
||||
print_thread(4);
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
printf("\n--- Warp 1 ---\n");
|
||||
__syncthreads();
|
||||
// Warp 1, Lane 0: rows {16, 80, 144, 208}, K=0..7
|
||||
print_thread(static_cast<int>(warp_size));
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
printf("\n--- Warp 3 (last) ---\n");
|
||||
__syncthreads();
|
||||
// Warp 3, Lane 63: rows {63, 127, 191, 255}, K=24..31
|
||||
print_thread(kBlockSize - 1);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int main()
|
||||
{
|
||||
printf("=== CK Tile Distribution Tutorial 1: A-Matrix DRAM Load ===\n");
|
||||
printf("=== Matches naive GEMM: MPerBlock=256, KPerBlock=32 ===\n\n");
|
||||
|
||||
HostTensor<int32_t> h_tensor({kM, kK});
|
||||
for(int i = 0; i < kM * kK; i++)
|
||||
h_tensor.mData[i] = i;
|
||||
|
||||
printf("Host matrix A[%d x %d], row-major, A[m][k] = m*%d + k\n\n", kM, kK, kK);
|
||||
|
||||
DeviceMem d_data(h_tensor);
|
||||
|
||||
launch_kernel(stream_config{},
|
||||
make_kernel<1>(TileDistKernelA{},
|
||||
dim3(1),
|
||||
dim3(TileDistKernelA::kBlockSize),
|
||||
0,
|
||||
static_cast<const int32_t*>(d_data.GetDeviceBuffer())));
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
printf("Done.\n");
|
||||
return 0;
|
||||
}
|
||||
240
tutorial/ck_tile/tile_distribution/tile_distribution_2.cpp
Normal file
240
tutorial/ck_tile/tile_distribution/tile_distribution_2.cpp
Normal file
@@ -0,0 +1,240 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/*
|
||||
* Tutorial: CK Tile Distribution Encoding — B Matrix DRAM Load
|
||||
*
|
||||
* Demonstrates how tile_distribution_encoding maps threads to B-matrix
|
||||
* elements during a DRAM load in the naive GEMM tutorial.
|
||||
*
|
||||
* Source: block_gemm_pipeline_agmem_bgmem_creg_policy.hpp
|
||||
* MakeBDramTileDistribution(), with fp16, BlockSize=256
|
||||
*
|
||||
* Tile: N=128 × K=32 (matches the naive GEMM's B block tile)
|
||||
* Threads: 256 (4 warps on CDNA, 8 on RDNA)
|
||||
*
|
||||
* The B encoding has the SAME structure as the A encoding (Tutorial 1),
|
||||
* but with N=128 instead of M=256. This changes only N0 (the iteration
|
||||
* count), showing how the same encoding pattern adapts to different
|
||||
* tile sizes.
|
||||
*
|
||||
* No compute is performed — this is purely about data movement.
|
||||
*
|
||||
* Note: int32_t is used instead of fp16 for readable printf output.
|
||||
* The distribution encoding is hardcoded to match the fp16 derivation.
|
||||
*
|
||||
* Note: Comments and values assume CDNA (warp_size=64). On RDNA (warp_size=32),
|
||||
* the thread-to-data mapping will differ.
|
||||
*/
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include <cstdio>
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
// ============================================================================
|
||||
// THE GOAL
|
||||
// ============================================================================
|
||||
// Matrix B: N=128 rows × K=32 columns, stored in DRAM (row-major, fp16).
|
||||
// (In GEMM, B is stored as [N, K] — each "row" is one output channel.)
|
||||
// Load the entire tile into registers using 256 threads (4 warps on CDNA).
|
||||
//
|
||||
// Same coalescing strategy as the A-matrix (Tutorial 1):
|
||||
// - 4 lanes cover one K-row (4 × 8 = 32 K-values)
|
||||
// - Each warp (64 lanes) covers 16 N-rows
|
||||
// - 4 warps cover 64 N-rows per iteration
|
||||
// - N0 = 128/64 = 2 iterations (vs 4 for A's M=256)
|
||||
//
|
||||
// Per-thread buffer = 2 iterations × 8 K-values = 16 elements.
|
||||
//
|
||||
// Compare with Tutorial 1 (A-matrix):
|
||||
// A: M=256, M0=4, buffer=32 | B: N=128, N0=2, buffer=16
|
||||
// Everything else is identical — same K-splitting, same coalescing.
|
||||
//
|
||||
// ============================================================================
|
||||
// THE SOLUTION: tile_distribution_encoding
|
||||
// ============================================================================
|
||||
//
|
||||
// Production code derives (fp16, BlockSize=256, NPerBlock=128, KPerBlock=32):
|
||||
// K1 = 16/sizeof(fp16) = 8
|
||||
// K0 = KPerBlock/K1 = 4
|
||||
// N2 = warp_size/K0 = 16
|
||||
// N1 = BlockSize/warp_size = 4
|
||||
// N0 = NPerBlock/(N2*N1) = 2
|
||||
//
|
||||
// Step 1 — Hierarchical dimensions (Hs):
|
||||
//
|
||||
// Hs[0] = sequence<2, 4, 16> → N = 2 × 4 × 16 = 128
|
||||
// Hs[1] = sequence<4, 8> → K = 4 × 8 = 32
|
||||
//
|
||||
// Hs[0] Hs[1]
|
||||
// ┌─────┼─────┐ ┌───┴───┐
|
||||
// [Y0] [P0] [P1] [P1] [Y1]
|
||||
// = 2 = 4 = 16 = 4 = 8
|
||||
// (iter) (warp) (row) (K-chunk) (vec load)
|
||||
//
|
||||
// Step 2 — Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id).
|
||||
//
|
||||
// Ps_major = tuple<sequence<1>, sequence<1, 2>>
|
||||
// Ps_minor = tuple<sequence<1>, sequence<2, 0>>
|
||||
//
|
||||
// How to read Ps: the tuple has 2 elements → NDimP=2.
|
||||
// First element = P0 = warp_id
|
||||
// Second element = P1 = lane_id
|
||||
//
|
||||
// P0: major=<1>, minor=<1> → Hs[0], level 1 → N1=4 (which warp)
|
||||
// P1: major=<1,2>, minor=<2,0> → merged:
|
||||
// Hs[0] level 2 → N2=16 (outer, row within warp)
|
||||
// Hs[1] level 0 → K0=4 (inner, K-chunk → coalesced!)
|
||||
// lane / 4 → row_in_warp, lane % 4 → K-chunk
|
||||
//
|
||||
// Step 3 — Yield dimensions (Ys): what each thread owns.
|
||||
//
|
||||
// Ys_major = sequence<1, 2>
|
||||
// Ys_minor = sequence<0, 1>
|
||||
//
|
||||
// How to read Ys: parallel arrays — position i gives Yi.
|
||||
//
|
||||
// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 1 > → Y0 is level 0, Y1 is level 1
|
||||
// ─Y0─ ─Y1─
|
||||
//
|
||||
// Y0: Hs[0] level 0 → N0=2 (iterations along N)
|
||||
// Y1: Hs[1] level 1 → K1=8 (vector load width)
|
||||
//
|
||||
// Buffer size = Y0 × Y1 = 2 × 8 = 16 elements per thread.
|
||||
//
|
||||
// ============================================================================
|
||||
|
||||
static constexpr index_t kN = 128;
|
||||
static constexpr index_t kK = 32;
|
||||
|
||||
struct TileDistKernelB
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const int32_t* p_data) const
|
||||
{
|
||||
static_assert(get_warp_size() == 64,
|
||||
"This tutorial is hard-coded for CDNA (warp_size=64). "
|
||||
"On RDNA (warp_size=32), the encoding values and print logic must change.");
|
||||
|
||||
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_data, make_tuple(kN, kK), make_tuple(kK, 1), number<1>{}, number<1>{});
|
||||
|
||||
constexpr auto distribution = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<2, 4, 16>, sequence<4, 8>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
|
||||
auto window = make_tile_window(
|
||||
b_tensor, make_tuple(number<kN>{}, number<kK>{}), {0, 0}, distribution);
|
||||
|
||||
const auto tile = load_tile(window);
|
||||
|
||||
const auto& buf = tile.get_thread_buffer();
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
constexpr index_t kBufSize = 16; // 2 iterations × 8 K-values
|
||||
|
||||
int32_t local_buf[kBufSize];
|
||||
static_for<0, kBufSize, 1>{}([&](auto i) { local_buf[i] = static_cast<int32_t>(buf[i]); });
|
||||
|
||||
auto print_thread = [&](int tid) {
|
||||
if(static_cast<int>(threadIdx.x) == tid)
|
||||
{
|
||||
int lane = tid % static_cast<int>(warp_size);
|
||||
int warp = tid / static_cast<int>(warp_size);
|
||||
int row_in_wrp = lane / 4;
|
||||
int k_chunk = lane % 4;
|
||||
|
||||
printf("Thread %3d (warp %d, lane %2d) row_in_warp=%2d k_chunk=%d\n",
|
||||
tid,
|
||||
warp,
|
||||
lane,
|
||||
row_in_wrp,
|
||||
k_chunk);
|
||||
|
||||
for(int iter = 0; iter < 2; iter++)
|
||||
{
|
||||
int row = iter * 64 + warp * 16 + row_in_wrp;
|
||||
int col = k_chunk * 8;
|
||||
printf(" iter %d: B[%3d][%2d..%2d] =", iter, row, col, col + 7);
|
||||
for(int k = 0; k < 8; k++)
|
||||
printf(" %4d", local_buf[iter * 8 + k]);
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if(blockIdx.x == 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("\n=== Tile Distribution: B-Matrix DRAM Load ===\n");
|
||||
printf("Source: MakeBDramTileDistribution (fp16, BlockSize=256)\n");
|
||||
printf("Tile: %dx%d BlockSize: %d WarpSize: %d Warps: %d\n",
|
||||
kN,
|
||||
kK,
|
||||
kBlockSize,
|
||||
static_cast<int>(warp_size),
|
||||
kBlockSize / static_cast<int>(warp_size));
|
||||
printf("Each thread: 2 iterations x 8 K-values = 16 elements\n");
|
||||
printf("Compare with Tutorial 1 (A): same K-split, but N0=2 vs M0=4\n\n");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Lane 0: row_in_warp=0, k_chunk=0 → rows {0, 64}, K=0..7
|
||||
print_thread(0);
|
||||
__syncthreads();
|
||||
// Lane 1: k_chunk=1 → same rows, K=8..15
|
||||
print_thread(1);
|
||||
__syncthreads();
|
||||
// Lane 4: row_in_warp=1 → rows {1, 65}, K=0..7
|
||||
print_thread(4);
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
printf("\n--- Warp 1 ---\n");
|
||||
__syncthreads();
|
||||
// Warp 1, Lane 0: rows {16, 80}, K=0..7
|
||||
print_thread(static_cast<int>(warp_size));
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
printf("\n--- Warp 3 (last) ---\n");
|
||||
__syncthreads();
|
||||
// Warp 3, Lane 63: rows {63, 127}, K=24..31
|
||||
print_thread(kBlockSize - 1);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int main()
|
||||
{
|
||||
printf("=== CK Tile Distribution Tutorial 2: B-Matrix DRAM Load ===\n");
|
||||
printf("=== Matches naive GEMM: NPerBlock=128, KPerBlock=32 ===\n\n");
|
||||
|
||||
HostTensor<int32_t> h_tensor({kN, kK});
|
||||
for(int i = 0; i < kN * kK; i++)
|
||||
h_tensor.mData[i] = i;
|
||||
|
||||
printf("Host matrix B[%d x %d], row-major, B[n][k] = n*%d + k\n\n", kN, kK, kK);
|
||||
|
||||
DeviceMem d_data(h_tensor);
|
||||
|
||||
launch_kernel(stream_config{},
|
||||
make_kernel<1>(TileDistKernelB{},
|
||||
dim3(1),
|
||||
dim3(TileDistKernelB::kBlockSize),
|
||||
0,
|
||||
static_cast<const int32_t*>(d_data.GetDeviceBuffer())));
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
printf("Done.\n");
|
||||
return 0;
|
||||
}
|
||||
376
tutorial/ck_tile/tile_distribution/tile_distribution_3.cpp
Normal file
376
tutorial/ck_tile/tile_distribution/tile_distribution_3.cpp
Normal file
@@ -0,0 +1,376 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/*
|
||||
* Tutorial: CK Tile Distribution Encoding — C Matrix Register Layout
|
||||
*
|
||||
* Demonstrates how C-matrix elements are distributed across thread registers
|
||||
* after MFMA computation. Unlike A/B (which are DRAM loads), C lives entirely
|
||||
* in registers — the distribution describes which thread holds which output
|
||||
* element of C = A × B.
|
||||
*
|
||||
* This tutorial shows BOTH:
|
||||
* 1. The warp-level C distribution (from MFMA m32n32k8 output mapping)
|
||||
* 2. The block-level outer distribution (how multiple warps tile C)
|
||||
* 3. The composed distribution (what CK actually uses)
|
||||
*
|
||||
* The macro CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION (default 1) selects
|
||||
* between the standard and transposed C register layouts.
|
||||
*
|
||||
* Tile: M=256 × N=128 (matches the naive GEMM's C block tile)
|
||||
* Warp config: MWarp=4, NWarp=1
|
||||
* MFMA: m32n32k8 (each warp produces a 32×32 output)
|
||||
*
|
||||
* No actual MFMA compute — we construct a C distributed tensor, fill it
|
||||
* with marker values (thread_id * 1000 + buffer_index), and print per-thread
|
||||
* contents to reveal which buffer slots belong to which thread.
|
||||
*
|
||||
* Note: Comments and values assume CDNA (warp_size=64). On RDNA (warp_size=32),
|
||||
* the thread-to-data mapping will differ.
|
||||
*/
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include <cstdio>
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
// Controls which C register layout to demonstrate
|
||||
#ifndef CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION
|
||||
#define CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION 1
|
||||
#endif
|
||||
|
||||
// ============================================================================
|
||||
// THE GOAL
|
||||
// ============================================================================
|
||||
// After GEMM computation, each thread holds a subset of the C matrix
|
||||
// (M=256 × N=128 = 32768 elements) in its registers. We want to understand
|
||||
// exactly which C[m][n] elements each thread owns.
|
||||
//
|
||||
// The mapping has two levels:
|
||||
//
|
||||
// BLOCK LEVEL (256×128 → warps and iterations):
|
||||
// - 4 warps along M (MWarp=4), 1 warp along N (NWarp=1)
|
||||
// - Each warp covers 32 M-rows × 128 N-cols of the block tile
|
||||
// - Within each warp: MIterPerWarp=2, NIterPerWarp=4
|
||||
// → 2 × 4 = 8 warp-tile iterations per warp
|
||||
// - Each warp-tile iteration is a 32×32 MFMA output
|
||||
//
|
||||
// WARP LEVEL (32×32 → threads):
|
||||
// - 64 threads produce 32 × 32 = 1024 C elements
|
||||
// - Each thread holds 1024/64 = 16 elements
|
||||
// - MFMA m32n32k8 arranges these 16 elements in a specific pattern
|
||||
//
|
||||
// The per-thread register buffer = 8 iterations × 16 elements = 128 floats.
|
||||
//
|
||||
// ============================================================================
|
||||
// THE SOLUTION: Two-Level Distribution
|
||||
// ============================================================================
|
||||
//
|
||||
// --- WARP-LEVEL C DISTRIBUTION (from MFMA m32n32k8) ---
|
||||
//
|
||||
// For fp16→fp32 MFMA m32n32k8 output (kCM0PerLane=4, kCMLane=2,
|
||||
// kCM1PerLane=4, kCNLane=32):
|
||||
//
|
||||
// STANDARD (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0):
|
||||
//
|
||||
// Hs[0] = sequence<4, 2, 4> → M-dim: 4 × 2 × 4 = 32
|
||||
// Hs[1] = sequence<32> → N-dim: 32
|
||||
// Ps_major = tuple<sequence<1, 2>> → lane maps to Hs[0][1] and Hs[1][0]
|
||||
// Ps_minor = tuple<sequence<1, 0>>
|
||||
//
|
||||
// How to read Ps: the tuple has 1 element → NDimP=1 → P0 = lane_id.
|
||||
// P0: major=<1,2>, minor=<1,0> → merged:
|
||||
// Hs[0] level 1 → kCMLane=2 (outer, M-half)
|
||||
// Hs[1] level 0 → kCNLane=32 (inner, N-col → contiguous!)
|
||||
// lane / 32 → M-half, lane % 32 → N-col
|
||||
//
|
||||
// Ys_major = sequence<1, 1>
|
||||
// Ys_minor = sequence<0, 2>
|
||||
//
|
||||
// How to read Ys: parallel arrays — position i gives Yi.
|
||||
//
|
||||
// Ys_major = seq< 1, 1 > → Y0 is in Hs[0], Y1 is in Hs[0]
|
||||
// Ys_minor = seq< 0, 2 > → Y0 is level 0, Y1 is level 2
|
||||
// ─Y0─ ─Y1─
|
||||
//
|
||||
// Y0: Hs[0] level 0 → kCM0PerLane=4 (M outer per lane)
|
||||
// Y1: Hs[0] level 2 → kCM1PerLane=4 (M inner per lane)
|
||||
//
|
||||
// Hs[0] Hs[1]
|
||||
// ┌─────┼─────┐ │
|
||||
// [Y0] [P0] [Y1] [P0]
|
||||
// = 4 = 2 = 4 = 32
|
||||
// (M outer)(lane) (M inner) (lane → N)
|
||||
//
|
||||
// Per-thread: Y0 × Y1 = 4 × 4 = 16 elements per warp-tile.
|
||||
// Lane decomposition: lane / 32 → M-half (0..1), lane % 32 → N-col (0..31)
|
||||
//
|
||||
// TRANSPOSED (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=1):
|
||||
//
|
||||
// Hs[0] = sequence<32> → First dim: N (swapped!)
|
||||
// Hs[1] = sequence<4, 2, 4> → Second dim: M (swapped!)
|
||||
// Ps_major = tuple<sequence<2, 1>> → lane maps to Hs[1][1] and Hs[0][0]
|
||||
// Ps_minor = tuple<sequence<1, 0>>
|
||||
//
|
||||
// How to read Ps: tuple has 1 element → NDimP=1 → P0 = lane_id.
|
||||
// P0: major=<2,1>, minor=<1,0> → merged:
|
||||
// Hs[1] level 1 → kCMLane=2 (outer, M-half)
|
||||
// Hs[0] level 0 → kCNLane=32 (inner, N-col → contiguous!)
|
||||
// Same lane decomposition as standard, but dimensions are swapped.
|
||||
//
|
||||
// Ys_major = sequence<2, 2>
|
||||
// Ys_minor = sequence<0, 2>
|
||||
//
|
||||
// How to read Ys:
|
||||
// Ys_major = seq< 2, 2 > → Y0 is in Hs[1], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 2 > → Y0 is level 0, Y1 is level 2
|
||||
// ─Y0─ ─Y1─
|
||||
//
|
||||
// Y0: Hs[1] level 0 → kCM0PerLane=4 (M outer per lane)
|
||||
// Y1: Hs[1] level 2 → kCM1PerLane=4 (M inner per lane)
|
||||
// Same 16 elements, but now both Y dims are in Hs[1] (M is second).
|
||||
//
|
||||
// Hs[0] Hs[1]
|
||||
// │ ┌─────┼─────┐
|
||||
// [P0] [Y0] [P0] [Y1]
|
||||
// = 32 = 4 = 2 = 4
|
||||
// (lane → N) (M outer)(lane)(M inner)
|
||||
//
|
||||
// Same 16 elements per thread, but N is the first dimension in the
|
||||
// distribution — this changes which elements are contiguous in the
|
||||
// thread buffer, affecting downstream store coalescing.
|
||||
//
|
||||
// --- BLOCK-LEVEL OUTER DISTRIBUTION ---
|
||||
//
|
||||
// MIterPerWarp = MPerBlock / (MWarp × WarpGemm::kM) = 256 / (4 × 32) = 2
|
||||
// NIterPerWarp = NPerBlock / (NWarp × WarpGemm::kN) = 128 / (1 × 32) = 4
|
||||
//
|
||||
// Hs[0] = sequence<2, 4> → M-dim: 2 iters × 4 warps
|
||||
// Hs[1] = sequence<4, 1> → N-dim: 4 iters × 1 warp
|
||||
// Ps_major = tuple<sequence<1, 2>>
|
||||
// Ps_minor = tuple<sequence<1, 1>>
|
||||
//
|
||||
// How to read Ps: tuple has 1 element → NDimP=1 → P0 = warp_id.
|
||||
// P0: major=<1,2>, minor=<1,1> → merged:
|
||||
// Hs[0] level 1 → MWarp=4 (outer)
|
||||
// Hs[1] level 1 → NWarp=1 (inner, trivial)
|
||||
// Total: 4 × 1 = 4 = number of warps
|
||||
//
|
||||
// Ys_major = sequence<1, 2>
|
||||
// Ys_minor = sequence<0, 0>
|
||||
//
|
||||
// How to read Ys:
|
||||
// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 0 > → Y0 is level 0, Y1 is level 0
|
||||
// ─Y0─ ─Y1─
|
||||
//
|
||||
// Y0: Hs[0] level 0 → MIterPerWarp=2
|
||||
// Y1: Hs[1] level 0 → NIterPerWarp=4
|
||||
// Block-level buffer = Y0 × Y1 = 2 × 4 = 8 warp-tile slots.
|
||||
//
|
||||
// tile_distribution_encoding<sequence<>,
|
||||
// tuple<sequence<2, 4>, sequence<4, 1>>,
|
||||
// tuple<sequence<1, 2>>, tuple<sequence<1, 1>>,
|
||||
// sequence<1, 2>, sequence<0, 0>>
|
||||
//
|
||||
// --- COMPOSED (what CK uses) ---
|
||||
//
|
||||
// make_embed_tile_distribution_encoding(block_outer, warp_encoding)
|
||||
// embeds the warp encoding inside each (MIter, MWarp, NIter, NWarp) cell.
|
||||
// Total per-thread buffer = 2 × 4 × 16 = 128 elements.
|
||||
//
|
||||
// ============================================================================
|
||||
|
||||
static constexpr index_t kM = 256;
|
||||
static constexpr index_t kN = 128;
|
||||
|
||||
#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION
|
||||
using WarpGemm = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution;
|
||||
#else
|
||||
using WarpGemm = WarpGemmMfmaF16F16F32M32N32K8;
|
||||
#endif
|
||||
|
||||
static constexpr index_t kMWarp = 4;
|
||||
static constexpr index_t kNWarp = 1;
|
||||
|
||||
static constexpr index_t kMIterPerWarp = kM / (kMWarp * WarpGemm::kM); // 2
|
||||
static constexpr index_t kNIterPerWarp = kN / (kNWarp * WarpGemm::kN); // 4
|
||||
|
||||
struct TileDistKernelC
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
|
||||
CK_TILE_DEVICE void operator()() const
|
||||
{
|
||||
static_assert(get_warp_size() == 64,
|
||||
"This tutorial is hard-coded for CDNA (warp_size=64). "
|
||||
"On RDNA (warp_size=32), the encoding values and print logic must change.");
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMIterPerWarp, kMWarp>, sequence<kNIterPerWarp, kNWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<float>(c_block_dstr);
|
||||
|
||||
constexpr index_t kBufSize = c_block_tensor.get_thread_buffer_size();
|
||||
|
||||
// Fill each thread's buffer with a marker value:
|
||||
// We can't easily set C[m][n] = m*N + n without knowing the inverse mapping,
|
||||
// so instead we fill with thread_id * 1000 + buffer_index to identify ownership.
|
||||
static_for<0, kBufSize, 1>{}([&](auto i) {
|
||||
c_block_tensor.get_thread_buffer()(i) =
|
||||
static_cast<float>(threadIdx.x * 1000 + static_cast<int>(i));
|
||||
});
|
||||
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
|
||||
// Copy compile-time-indexed buffer into a plain array for runtime printing
|
||||
float local_buf[kBufSize];
|
||||
static_for<0, kBufSize, 1>{}(
|
||||
[&](auto i) { local_buf[i] = c_block_tensor.get_thread_buffer()[i]; });
|
||||
|
||||
auto print_thread = [&](int tid) {
|
||||
if(static_cast<int>(threadIdx.x) == tid)
|
||||
{
|
||||
int lane = tid % static_cast<int>(warp_size);
|
||||
int warp = tid / static_cast<int>(warp_size);
|
||||
|
||||
printf("Thread %3d (warp %d, lane %2d) buf_size=%d\n",
|
||||
tid,
|
||||
warp,
|
||||
lane,
|
||||
static_cast<int>(kBufSize));
|
||||
|
||||
#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION
|
||||
printf(" Layout: TRANSPOSED (N is first dimension)\n");
|
||||
#else
|
||||
printf(" Layout: STANDARD (M is first dimension)\n");
|
||||
#endif
|
||||
|
||||
printf(" Block-level: MIterPerWarp=%d, NIterPerWarp=%d\n",
|
||||
static_cast<int>(kMIterPerWarp),
|
||||
static_cast<int>(kNIterPerWarp));
|
||||
printf(" Warp-level: 16 elements per warp-tile (32x32 MFMA output)\n");
|
||||
printf(" Total: %d x %d x 16 = %d elements\n",
|
||||
static_cast<int>(kMIterPerWarp),
|
||||
static_cast<int>(kNIterPerWarp),
|
||||
static_cast<int>(kBufSize));
|
||||
|
||||
constexpr int kPerWarpTile = 16;
|
||||
for(int mIter = 0; mIter < static_cast<int>(kMIterPerWarp); mIter++)
|
||||
{
|
||||
for(int nIter = 0; nIter < static_cast<int>(kNIterPerWarp); nIter++)
|
||||
{
|
||||
int base = (mIter * static_cast<int>(kNIterPerWarp) + nIter) * kPerWarpTile;
|
||||
printf(" [mIter=%d, nIter=%d] buf[%3d..%3d]:",
|
||||
mIter,
|
||||
nIter,
|
||||
base,
|
||||
base + kPerWarpTile - 1);
|
||||
for(int k = 0; k < kPerWarpTile; k++)
|
||||
{
|
||||
printf(" %.0f", static_cast<double>(local_buf[base + k]));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if(blockIdx.x == 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("\n=== Tile Distribution: C-Matrix Register Layout ===\n");
|
||||
printf("Tile: %dx%d BlockSize: %d WarpSize: %d\n",
|
||||
static_cast<int>(kM),
|
||||
static_cast<int>(kN),
|
||||
static_cast<int>(kBlockSize),
|
||||
static_cast<int>(warp_size));
|
||||
printf("MWarp=%d, NWarp=%d, MFMA=m32n32k8\n",
|
||||
static_cast<int>(kMWarp),
|
||||
static_cast<int>(kNWarp));
|
||||
printf("MIterPerWarp=%d, NIterPerWarp=%d\n",
|
||||
static_cast<int>(kMIterPerWarp),
|
||||
static_cast<int>(kNIterPerWarp));
|
||||
#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION
|
||||
printf("Mode: TRANSPOSED C (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=1)\n");
|
||||
printf(" WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution\n");
|
||||
printf(" Warp encoding: <seq<>, tuple<seq<32>, seq<4,2,4>>,\n");
|
||||
printf(" tuple<seq<2,1>>, tuple<seq<1,0>>,\n");
|
||||
printf(" seq<2,2>, seq<0,2>>\n");
|
||||
#else
|
||||
printf("Mode: STANDARD C (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0)\n");
|
||||
printf(" WarpGemmMfmaF16F16F32M32N32K8\n");
|
||||
printf(" Warp encoding: <seq<>, tuple<seq<4,2,4>, seq<32>>,\n");
|
||||
printf(" tuple<seq<1,2>>, tuple<seq<1,0>>,\n");
|
||||
printf(" seq<1,1>, seq<0,2>>\n");
|
||||
#endif
|
||||
printf("\nBlock outer: <seq<>, tuple<seq<%d,%d>, seq<%d,%d>>,\n",
|
||||
static_cast<int>(kMIterPerWarp),
|
||||
static_cast<int>(kMWarp),
|
||||
static_cast<int>(kNIterPerWarp),
|
||||
static_cast<int>(kNWarp));
|
||||
printf(" tuple<seq<1,2>>, tuple<seq<1,1>>,\n");
|
||||
printf(" seq<1,2>, seq<0,0>>\n\n");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Warp 0, Lane 0
|
||||
print_thread(0);
|
||||
__syncthreads();
|
||||
// Warp 0, Lane 32 (different M-half in standard, different N in transposed)
|
||||
print_thread(32);
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
printf("\n--- Warp 1 (covers different M-rows than warp 0) ---\n");
|
||||
__syncthreads();
|
||||
print_thread(static_cast<int>(warp_size));
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
printf("\n--- Warp 3 (last) ---\n");
|
||||
__syncthreads();
|
||||
print_thread(kBlockSize - 1);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int main()
|
||||
{
|
||||
printf("=== CK Tile Distribution Tutorial 3: C-Matrix Register Layout ===\n");
|
||||
printf("=== Matches naive GEMM: MPerBlock=256, NPerBlock=128 ===\n\n");
|
||||
printf("MFMA m32n32k8: each warp produces 32x32 = 1024 elements\n");
|
||||
printf(" 64 threads per warp → 16 elements per thread per warp-tile\n");
|
||||
printf(" MWarp=4, NWarp=1 → 4 warps along M, 1 along N\n");
|
||||
printf(" MIterPerWarp=2, NIterPerWarp=4 → 8 warp-tiles per warp\n");
|
||||
printf(" Total per thread: 8 × 16 = 128 elements\n\n");
|
||||
|
||||
#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION
|
||||
printf("Current mode: TRANSPOSED C distribution\n");
|
||||
printf(" Rebuild with -DCK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0 for standard\n\n");
|
||||
#else
|
||||
printf("Current mode: STANDARD C distribution\n");
|
||||
printf(" Rebuild with -DCK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=1 for transposed\n\n");
|
||||
#endif
|
||||
|
||||
launch_kernel(stream_config{},
|
||||
make_kernel<1>(TileDistKernelC{}, dim3(1), dim3(TileDistKernelC::kBlockSize), 0));
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
printf("Done.\n");
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user