Add CK Tile Tutorials Folder with GEMM and COPY Kernel (#3038)

* feat: add tutorial folder with gemm tutorial

* chore: move copy kernel from examples folder to tutorial

* Update tutorial/ck_tile/01_naive_gemm/README.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tutorial/ck_tile/01_naive_gemm/README.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* chore: remove handdrawn images

* docs: add write ups to explain the gemm kernel

* docs: add about block level pipeline and static distributed tensors

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-11-11 15:15:49 -05:00
committed by GitHub
parent c54ecd905b
commit b145a5fe80
24 changed files with 3287 additions and 15 deletions

View File

@@ -0,0 +1,589 @@
# Block-Level Pipeline: PracticeGemmBlockPipelineAGmemBGmemCreg
## Overview
The **Block-Level Pipeline** is where the actual GEMM computation happens for one block tile. It orchestrates:
1. **Data movement** from DRAM → Registers → LDS
2. **GEMM computation** using data in LDS
3. **Iteration** over the K dimension when needed
This pipeline is called by the host-level pipeline for each block tile that covers a portion of the output matrix C.
---
## Architecture: Problem and Policy
Like other components in CK Tile, the block pipeline follows the **Problem/Policy** pattern:
### Problem: `PracticeGemmBlockPipelineProblem`
Contains:
- **Data types**: `ADataType`, `BDataType`, `CDataType`, `AccDataType`
- **Shape information**: `BlockTile` and `WaveTile` dimensions
### Policy: `PracticeGemmBlockPolicy`
Contains strategies for:
1. **Tile Distribution** (`MakeADramTileDistribution`, `MakeBDramTileDistribution`)
- Defines how 256 threads in a block map to elements of a block tile
- Each thread knows which elements to load/store from DRAM to its registers
- We'll cover tile distribution construction in detail later
2. **LDS Layout** (`MakeALdsBlockDescriptor`, `MakeBLdsBlockDescriptor`)
- Describes how data is logically organized in Local Data Share (LDS)
- Optimizes for bank conflict avoidance and efficient access patterns
- We'll cover LDS descriptor construction in detail later
3. **Warp Pipeline** (`GetPracticeWaveGemmPipeline`)
- Returns the warp-level GEMM implementation
---
## Inputs and Outputs
```cpp
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
```
### Inputs:
- `a_dram_block_window_tmp`: Tile window over A in DRAM (size: MPerBlock × KPerBlock)
- `b_dram_block_window_tmp`: Tile window over B in DRAM (size: NPerBlock × KPerBlock)
- `num_loop`: Number of iterations along K dimension
- `p_smem`: Pointer to shared memory (LDS)
### Output:
- `c_block_tile`: A `static_distributed_tensor` containing the computed C tile in registers (VGPRs)
---
## Step-by-Step Walkthrough
### Step 1: Create LDS Tensor Views
```cpp
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// B tile in LDS (placed after A in shared memory)
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
```
**What's happening:**
- We partition the shared memory (`p_smem`) into two regions: one for A, one for B
- We create **tensor views** over these LDS regions using descriptors from the policy
- `a_lds_block` and `b_lds_block` are logical views over raw LDS memory
**Memory Layout:**
```
Shared Memory (LDS):
┌─────────────────────┬─────────────────────┐
│ A Block Tile │ B Block Tile │
│ (256×32 fp16) │ (128×32 fp16) │
└─────────────────────┴─────────────────────┘
↑ ↑
p_a_lds p_b_lds
```
---
### Step 2: Create Tile Windows for Data Movement
We create **6 tile windows** for different purposes:
#### 2a. DRAM → Registers (Load from DRAM)
```cpp
auto a_copy_dram_window = make_tile_window(
a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>()); // ← Tile distribution!
```
**Key Points:**
- `a_copy_dram_window` is a `tile_window_with_static_distribution`
- The **tile distribution** tells each thread which elements to load from DRAM
- This window will **slide along the K dimension** in the loop
#### 2b. Registers → LDS (Store to LDS)
```cpp
auto a_copy_lds_window = make_tile_window(
a_lds_block,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32
{0, 0}, // Origin at (0, 0) in LDS
a_copy_dram_window.get_tile_distribution()); // ← Same distribution as DRAM!
```
**Key Points:**
- Uses the **same tile distribution** as `a_copy_dram_window`
- This ensures each thread stores to LDS in the same pattern it loaded from DRAM
- Origin is always `{0, 0}` because LDS is reused for each K iteration
#### 2c. LDS → Registers (GEMM Input)
```cpp
auto a_lds_gemm_window = make_tile_window(
a_lds_block,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0}); // No tile distribution!
```
**Key Points:**
- This is a `tile_window_with_static_lengths` (no explicit distribution)
- Used as input to the warp-level GEMM
- The warp GEMM will handle its own thread mapping internally
**Similar windows are created for B:**
- `b_copy_dram_window`: Load B from DRAM
- `b_copy_lds_window`: Store B to LDS
- `b_lds_gemm_window`: Read B from LDS for GEMM
---
### Step 3: Create Distributed Tensors (VGPRs)
```cpp
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile; // Per-thread registers for A
BBlockTile b_block_tile; // Per-thread registers for B
```
#### What is `make_static_distributed_tensor`?
**`make_static_distributed_tensor`** creates a **`static_distributed_tensor`**, which is a compile-time abstraction for **distributed per-thread register storage**.
**Key Properties:**
1. **Per-thread VGPRs**: Each thread owns a **different slice** of the tile in its registers
2. **Compile-time sized**: Buffer size determined by tile distribution at compile time
3. **Zero-overhead**: All indexing and layout transformations happen at compile time
**How it works:**
```cpp
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor
{
using DataType = remove_cvref_t<DataType_>;
using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;
// Calculate per-thread storage size from tile distribution
using ThreadTensorDesc =
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
static constexpr index_t kThreadElementSpaceSize =
ThreadTensorDesc{}.get_element_space_size();
// Per-thread register array (VGPRs)
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
};
```
**The tile distribution defines:**
- **Which elements each thread owns** in the tile
- **How many elements** each thread stores (buffer size)
- **How elements are laid out** in each thread's registers
**Concrete Example for 256×32 tile with 256 threads:**
```
Thread 0: a_block_tile.thread_buf_ = [A[0,0], A[0,1], ..., A[0,31]] (32 fp16 values)
Thread 1: a_block_tile.thread_buf_ = [A[1,0], A[1,1], ..., A[1,31]] (32 fp16 values)
Thread 2: a_block_tile.thread_buf_ = [A[2,0], A[2,1], ..., A[2,31]] (32 fp16 values)
...
Thread 255: a_block_tile.thread_buf_ = [A[255,0], A[255,1], ..., A[255,31]] (32 fp16 values)
```
**Collectively:**
- All 256 threads together hold the **entire 256×32 tile** (8192 elements)
- Each thread's buffer lives in its **own VGPRs**
- No two threads own the same element
**Distributed Ownership Analogy:**
Think of a tile as a **jigsaw puzzle**:
- The **tile distribution** is the cutting pattern
- Each **thread** gets one puzzle piece (its slice)
- Each **`static_distributed_tensor`** is a box holding all pieces
- Each thread's **`thread_buf_`** is its individual piece in its own registers
---
### Step 4: The GEMM Loop
```cpp
// Initialize C accumulator to zero
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
index_t iCounter = num_loop; // Number of K iterations
while(iCounter > 0)
{
// 1. Load from DRAM to registers
a_block_tile = load_tile(a_copy_dram_window); // DRAM → VGPRs
b_block_tile = load_tile(b_copy_dram_window); // DRAM → VGPRs
// 2. Move windows for next iteration
move_tile_window(a_copy_dram_window, a_dram_tile_window_step); // Step by (0, 32)
move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // Step by (0, 32)
// 3. Store from registers to LDS
store_tile(a_copy_lds_window, a_block_tile); // VGPRs → LDS
store_tile(b_copy_lds_window, b_block_tile); // VGPRs → LDS
// 4. Synchronize threads (ensure all data is in LDS)
block_sync_lds();
// 5. Compute GEMM using data in LDS
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// 6. Synchronize threads (before overwriting LDS in next iteration)
block_sync_lds();
iCounter--;
}
return c_block_tile; // Return accumulated result in registers
```
---
## Detailed Loop Breakdown
### Phase 1: Load (DRAM → VGPRs)
```cpp
a_block_tile = load_tile(a_copy_dram_window);
```
**What happens:**
1. Each thread reads **its assigned elements** from DRAM (determined by tile distribution)
2. Data is loaded into **per-thread registers** (VGPRs)
3. Uses **vectorized loads** for efficiency (e.g., loading 8 fp16 values at once)
**Example for Thread 0:**
```
Thread 0 loads:
A[0,0:7] (8 fp16 values, one vector load)
A[1,0:7] (8 fp16 values, one vector load)
...
```
### Phase 2: Move Windows
```cpp
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
```
**What happens:**
- The tile window **slides along the K dimension** by `KPerBlock` (32 in our example)
- This prepares for the next K iteration
- The window origin moves from `(0, 0)``(0, 32)``(0, 64)` → ...
**Visualization for Problem Size 512×256×64:**
```
Matrix A (512×64):
┌─────────────────────────────────────┐
│ Block 0: rows 0-255 │
│ ┌──────────┬──────────┐ │
│ │ K=0:31 │ K=32:63 │ │ ← Window slides right
│ │ Iter 0 │ Iter 1 │ │
│ └──────────┴──────────┘ │
└─────────────────────────────────────┘
```
### Phase 3: Store (VGPRs → LDS)
```cpp
store_tile(a_copy_lds_window, a_block_tile);
```
**What happens:**
1. Each thread writes **its elements** from registers to LDS
2. Uses the **same distribution** as the DRAM load
3. Data is now in **shared memory**, accessible to all threads in the block
**Why this step?**
- GEMM computation needs **all threads** to access **all data**
- Registers are per-thread; LDS is shared across the block
- LDS acts as a "staging area" for collaborative computation
### Phase 4: Synchronize
```cpp
block_sync_lds();
```
**What happens:**
- All threads in the block **wait** until everyone has finished storing to LDS
- Ensures no thread starts reading from LDS before all writes are complete
- Critical for correctness!
### Phase 5: GEMM Computation
```cpp
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
```
**What happens:**
1. The warp-level GEMM reads data from LDS
2. Performs matrix multiplication using MFMA instructions
3. Accumulates results into `c_block_tile` (in registers)
**Note:** `c_block_tile` stays in registers throughout all K iterations, accumulating results.
### Phase 6: Synchronize Again
```cpp
block_sync_lds();
```
**What happens:**
- Ensures all threads have finished reading from LDS
- Safe to overwrite LDS in the next iteration
---
## Memory Flow Diagram
```
Iteration 0 (K=0:31):
┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐
│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │
│ A[0:255,│ │ (per- │ │ A_block │
│ 0:31] │ │ thread) │ │ │
└─────────┘ └──────────┘ └─────────┘
│ block_gemm
┌──────────┐
│ c_block_ │
│ tile │
│ (VGPRs) │
└──────────┘
Iteration 1 (K=32:63):
┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐
│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │
│ A[0:255,│ │ (per- │ │ A_block │
│ 32:63] │ │ thread) │ │ (reused)│
└─────────┘ └──────────┘ └─────────┘
│ block_gemm
┌──────────┐
│ c_block_ │
│ tile │
│ (accum.) │
└──────────┘
```
---
## Example: Problem Size 512×256×64
### Block 0 Computation
**Input:**
- `a_dram_block_window_tmp`: Covers A[0:255, 0:31] initially
- `b_dram_block_window_tmp`: Covers B[0:127, 0:31] initially (B is transposed)
- `num_loop`: 2 (since K=64, KPerBlock=32)
**Iteration 0:**
1. Load A[0:255, 0:31] and B[0:127, 0:31] from DRAM to VGPRs
2. Move windows: A → [0:255, 32:63], B → [0:127, 32:63]
3. Store to LDS
4. Compute: `C[0:255, 0:127] += A[0:255, 0:31] × B[0:127, 0:31]^T`
**Iteration 1:**
1. Load A[0:255, 32:63] and B[0:127, 32:63] from DRAM to VGPRs
2. Move windows: A → [0:255, 64:95], B → [0:127, 64:95] (out of bounds, but loop ends)
3. Store to LDS
4. Compute: `C[0:255, 0:127] += A[0:255, 32:63] × B[0:127, 32:63]^T`
**Output:**
- `c_block_tile`: Contains C[0:255, 0:127] in distributed registers
---
## Key Concepts Summary
### 1. Tile Distribution
- **Maps threads to data elements** for load/store operations
- Each thread knows exactly which elements it's responsible for
- Enables **parallel, vectorized** memory access
- **Same distribution** used for DRAM load and LDS store
### 2. Static Distributed Tensor
- **Per-thread register storage** (VGPRs)
- Each thread owns a **different slice** of the tile
- **Compile-time sized** for zero-overhead abstraction
- Used for: `a_block_tile`, `b_block_tile`, `c_block_tile`
### 3. Tile Window Movement
- Windows **slide** over larger tensors
- Enables iteration over the K dimension
- `move_tile_window(window, step)` updates the origin
### 4. LDS as Staging Area
- **Shared memory** accessible to all threads in a block
- Required because GEMM needs all threads to access all data
- **Reused** across K iterations (same LDS buffer)
### 5. Synchronization
- `block_sync_lds()` ensures memory consistency
- **Before GEMM**: All stores to LDS are complete
- **After GEMM**: All reads from LDS are complete
---
## Deep Dive: `static_distributed_tensor` Mechanics
### How Tile Distribution Creates Per-Thread Storage
When you call:
```cpp
using ABlockTile = decltype(make_static_distributed_tensor<fp16_t>(ABlockTileDistr{}));
ABlockTile a_block_tile;
```
**Step 1: Extract Thread Tensor Descriptor**
The tile distribution contains a `ys_to_d_descriptor` that maps:
- **Y dimensions** (logical tile coordinates, e.g., M, K)
- **D dimension** (per-thread register index, linearized)
```cpp
using ThreadTensorDesc =
decltype(StaticTileDistribution{}.get_ys_to_d_descriptor());
```
**Step 2: Calculate Per-Thread Buffer Size**
```cpp
static constexpr index_t kThreadElementSpaceSize =
ThreadTensorDesc{}.get_element_space_size();
static constexpr index_t get_thread_buffer_size()
{
return kThreadElementSpaceSize / PackedSize;
}
```
**Example:**
- 256×32 tile distributed across 256 threads
- Each thread owns 32 elements (one row)
- `thread_buffer_size = 32` (for PackedSize=1)
**Step 3: Allocate Thread Buffer**
```cpp
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
```
This is essentially:
```cpp
fp16_t data[32]; // Per-thread register array (VGPRs)
```
### Usage in Load/Store Operations
**Load from DRAM:**
```cpp
a_block_tile = load_tile(a_copy_dram_window);
```
What happens internally:
1. Each thread queries the tile distribution: "Which elements do I own?"
2. Thread 0 learns it owns A[0,0:31]
3. Thread 0 loads those elements from DRAM into `a_block_tile.thread_buf_[0:31]`
4. All 256 threads do this **in parallel**
**Store to LDS:**
```cpp
store_tile(a_copy_lds_window, a_block_tile);
```
What happens internally:
1. Each thread reads from its `a_block_tile.thread_buf_`
2. Thread 0 writes A[0,0:31] from its registers to LDS
3. All 256 threads do this **in parallel**
4. After `block_sync_lds()`, the entire tile is in shared LDS
### Distributed Indexing
The `static_distributed_tensor` supports compile-time indexing:
```cpp
// Access using distributed indices
auto value = a_block_tile(tile_distributed_index<i, j>{});
```
Internally:
1. Convert distributed index → Y index (logical tile coordinates)
2. Calculate buffer offset using `ThreadTensorDesc`
3. Access `thread_buf_[offset]`
All of this happens **at compile time** with zero runtime overhead!
### Why This Design?
**Benefits:**
1. **Parallel Memory Access**: All threads load/store simultaneously
2. **Vectorization**: Each thread can use vector loads (e.g., 8×fp16 at once)
3. **Zero Overhead**: All indexing resolved at compile time
4. **Type Safety**: Distribution mismatch caught at compile time
5. **Register Pressure**: Compiler knows exact VGPR usage
**Trade-offs:**
- Requires compile-time tile sizes
- Distribution must be static
- More complex type system
### Memory Hierarchy Summary
```
┌─────────────────────────────────────────────────────────────┐
│ DRAM (Global Memory) │
│ Full matrices A, B, C │
└─────────────────────────────────────────────────────────────┘
│ load_tile (parallel, vectorized)
┌─────────────────────────────────────────────────────────────┐
│ VGPRs (Per-Thread Registers) │
│ Thread 0: a_block_tile.thread_buf_ = [A[0,0:31]] │
│ Thread 1: a_block_tile.thread_buf_ = [A[1,0:31]] │
│ ... │
│ Thread 255: a_block_tile.thread_buf_ = [A[255,0:31]] │
│ │
│ ← static_distributed_tensor manages this distribution │
└─────────────────────────────────────────────────────────────┘
│ store_tile (parallel, vectorized)
┌─────────────────────────────────────────────────────────────┐
│ LDS (Shared Memory) │
│ Entire block tile (256×32) │
│ Accessible to all threads in block │
└─────────────────────────────────────────────────────────────┘
```
**Key Insight:**
`static_distributed_tensor` is the abstraction that enables efficient, parallel data movement between DRAM and LDS through per-thread VGPRs, with all coordination happening at compile time.

View File

@@ -0,0 +1,7 @@
add_executable(tile_tutorial_naive_gemm EXCLUDE_FROM_ALL practice_gemm.cpp)
target_compile_options(tile_tutorial_naive_gemm PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
add_dependencies(tutorials tile_tutorial_naive_gemm)

View File

@@ -0,0 +1,618 @@
# Host-Level Pipeline: Orchestrating Block-Level GEMM
This document explains the **host-level pipeline** (`PracticeGemmHostPipeline`), which orchestrates the distribution of work across thread blocks and manages the high-level flow of the GEMM computation.
## Overview
The host-level pipeline is responsible for:
1. **Calculating tile coverage**: How many tiles are needed to cover matrices A, B, and C
2. **Block-to-tile mapping**: Assigning each thread block to a specific tile
3. **Creating tile windows**: Establishing sliding windows over tensor views
4. **Delegating computation**: Calling the block-level pipeline to perform actual GEMM
5. **Storing results**: Writing computed tiles from registers (VGPRs) back to DRAM
```cpp
template <typename Problem_, typename Policy_ = PracticeGemmHostPolicy>
struct PracticeGemmHostPipeline
{
template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
const BDRAMTensorView& b_dram,
CDRAMTensorView& c_dram) const
{
// 1. Calculate problem dimensions and tile coverage
// 2. Map thread block to tile coordinates
// 3. Create tile windows over A and B
// 4. Call block-level pipeline to compute
// 5. Store result to C
}
};
```
---
## Step 1: Calculate Problem Dimensions and Tile Coverage
```cpp
// Size of the entire problem
const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K
const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N
const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K
// Size of the block tile
const auto MPerBlock = BlockTile::at(number<0>{}); // 256
const auto NPerBlock = BlockTile::at(number<1>{}); // 128
const auto KPerBlock = BlockTile::at(number<2>{}); // 32
// Number of block tiles needed to cover C matrix
const auto num_tile_n = integer_divide_ceil(N, NPerBlock); // ceil(256/128) = 2
const auto num_tile_m = integer_divide_ceil(M, MPerBlock); // ceil(512/256) = 2
```
### What's Happening:
1. **Extract problem dimensions** from tensor descriptors:
- `M = 512`: Rows in A and C
- `N = 256`: Columns in B and C
- `K = 64`: Inner dimension (columns of A, rows of B)
2. **Get block tile sizes** from the `BlockTile` configuration:
- `MPerBlock = 256`: Each block processes 256 rows
- `NPerBlock = 128`: Each block processes 128 columns
- `KPerBlock = 32`: Each block processes 32 elements in K dimension per iteration
3. **Calculate tile coverage**:
- `num_tile_m = ceil(M / MPerBlock) = ceil(512/256) = 2` tiles in M direction
- `num_tile_n = ceil(N / NPerBlock) = ceil(256/128) = 2` tiles in N direction
- **Total tiles = 2 × 2 = 4 tiles** → We need **4 thread blocks**!
### Visual Representation:
```
Matrix C (512 × 256):
┌──────────────────────┬──────────────────────┐
│ Tile (0,0) │ Tile (0,1) │ ← num_tile_n = 2
│ 256×128 │ 256×128 │
│ Block 0 │ Block 1 │
│ │ │
├──────────────────────┼──────────────────────┤
│ Tile (1,0) │ Tile (1,1) │
│ 256×128 │ 256×128 │
│ Block 2 │ Block 3 │
│ │ │
└──────────────────────┴──────────────────────┘
num_tile_m = 2
Total blocks needed = 2 × 2 = 4 blocks
Each block computes one 256×128 tile of the output matrix C.
```
### How Blocks Cover Matrices A and B:
```
Matrix A (512 × 64): Matrix B (256 × 64):
┌─────────────┬──────┐ ┌─────────────┬──────┐
│ Block 0,2 │ K │ │ Block 0,1 │ K │
│ uses rows │ → │ │ uses rows │ → │
│ 0-255 │ │ │ 0-127 │ │
├─────────────┼──────┤ ├─────────────┼──────┤
│ Block 1,3 │ K │ │ Block 2,3 │ K │
│ uses rows │ → │ │ uses rows │ → │
│ 256-511 │ │ │ 128-255 │ │
└─────────────┴──────┘ └─────────────┴──────┘
256 rows 64 cols 128 rows 64 cols
Each block needs to iterate over K dimension (64/32 = 2 iterations)
```
---
## Step 2: Map Thread Block to Tile Coordinates
```cpp
// Get block id (0 to total_blocks - 1)
const auto id_block = get_block_id();
// Map block id to 2D tile coordinates
const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n);
const auto tile_id = block2tile(id_block);
const auto tile_id_m = tile_id.at(number<0>{}); // M coordinate
const auto tile_id_n = tile_id.at(number<1>{}); // N coordinate
```
### What's Happening:
Each thread block needs to know **which tile of the output matrix C it should compute**. The `MakeBlock2TileMap` function creates a mapping from linear block ID to 2D tile coordinates.
### The `MakeBlock2TileMap` Function:
```cpp
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
{
// Create a merge transform: (N0, M0) → linear index
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
return [unmerge](index_t block_id) {
multi_index<2> unmerged;
// Convert linear block_id back to 2D coordinates
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
// Return (m_idx, n_idx) - note the swap!
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
};
}
```
### In Our Example (2×2 Grid):
```cpp
// Block 0:
id_block = 0
tile_id = block2tile(0) = (0, 0) // Top-left tile
tile_id_m = 0, tile_id_n = 0
// Block 1:
id_block = 1
tile_id = block2tile(1) = (1, 0) // Bottom-left tile
tile_id_m = 1, tile_id_n = 0
// Block 2:
id_block = 2
tile_id = block2tile(2) = (0, 1) // Top-right tile
tile_id_m = 0, tile_id_n = 1
// Block 3:
id_block = 3
tile_id = block2tile(3) = (1, 1) // Bottom-right tile
tile_id_m = 1, tile_id_n = 1
```
**Key Point**: Each of the 4 blocks knows exactly which 256×128 tile of C it's responsible for computing!
---
## Step 3: Calculate Tile Origin and Create Tile Windows
```cpp
// Calculate the starting position of this tile in the global matrix
const auto tile_origin_m = tile_id_m * MPerBlock; // e.g., Block 1: 1 * 256 = 256
const auto tile_origin_n = tile_id_n * NPerBlock; // e.g., Block 2: 1 * 128 = 128
// Create tile windows over A and B tensor views
const auto a_block_window = make_tile_window(
a_dram, // Tensor view over A
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // Window size: 256×32
{tile_origin_m, 0} // Origin: varies by block
);
const auto b_block_window = make_tile_window(
b_dram, // Tensor view over B
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), // Window size: 128×32
{tile_origin_n, 0} // Origin: varies by block
);
```
### Tile Origins for Each Block:
```cpp
// Block 0 (Tile 0,0):
tile_origin_m = 0 * 256 = 0
tile_origin_n = 0 * 128 = 0
a_block_window origin: (0, 0) covers A rows 0-255
b_block_window origin: (0, 0) covers B rows 0-127
// Block 1 (Tile 1,0):
tile_origin_m = 1 * 256 = 256
tile_origin_n = 0 * 128 = 0
a_block_window origin: (256, 0) covers A rows 256-511
b_block_window origin: (0, 0) covers B rows 0-127
// Block 2 (Tile 0,1):
tile_origin_m = 0 * 256 = 0
tile_origin_n = 1 * 128 = 128
a_block_window origin: (0, 0) covers A rows 0-255
b_block_window origin: (128, 0) covers B rows 128-255
// Block 3 (Tile 1,1):
tile_origin_m = 1 * 256 = 256
tile_origin_n = 1 * 128 = 128
a_block_window origin: (256, 0) covers A rows 256-511
b_block_window origin: (128, 0) covers B rows 128-255
```
### What are Tile Windows?
A **tile window** is a **sliding window** over a larger tensor view. It:
- Defines a **rectangular region** within the tensor
- Has a **fixed size** (e.g., 256×32 for A)
- Has an **origin** (starting position)
- Can be **moved** to access different regions
### Visual Representation (Block 0 Example):
```
Matrix A (512 × 64): Matrix B (256 × 64):
┌─────────────┬─────────────┐ ┌─────────────┬─────────────┐
│ ┏━━━━━━━━━┓ │ │ │ ┏━━━━━━━━━┓ │ │
│ ┃ Window ┃ │ │ │ ┃ Window ┃ │ │
│ ┃ 256×32 ┃ │ │ │ ┃ 128×32 ┃ │ │
│ ┃ K=0-31 ┃ │ │ │ ┃ K=0-31 ┃ │ │
│ ┗━━━━━━━━━┛ │ │ │ ┗━━━━━━━━━┛ │ │
│ │ │ ├─────────────┼─────────────┤
├─────────────┼─────────────┤ │ │ │
│ │ │ │ │ │
│ │ │ │ │ │
│ │ │ │ │ │
└─────────────┴─────────────┘ └─────────────┴─────────────┘
Origin: (0, 0) Origin: (0, 0)
Covers rows 0-255 Covers rows 0-127
Covers cols 0-31 (first K iteration) Covers cols 0-31 (first K iteration)
```
**Note**: The window initially covers K columns 0-31. It will move to cover K columns 32-63 in the next iteration.
### Tile Window Properties:
```cpp
// Tile window structure (conceptual):
struct tile_window {
TensorView& tensor_view; // Reference to underlying tensor
Tuple window_lengths; // Size of the window (256, 32)
MultiIndex window_origin; // Starting position (0, 0)
// Can move the window:
void move(MultiIndex step); // Shift window by step
// Access data through the window:
auto load(); // Load data from windowed region
};
```
### Tile Window Movement: Iterating Over K Dimension
In our example, **K=64** but **KPerBlock=32**, so we need **2 iterations** over the K dimension:
```
Matrix A (512 × 64) - Block 0's view:
┌─────────────┬─────────────┐
│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │
│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ ← Window slides along K
│ ┃ 256×32 ┃ │ ║ 256×32 ║ │
│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │
│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │
├─────────────┼─────────────┤
│ │ │
│ Block 1's │ │
│ region │ │
└─────────────┴─────────────┘
Matrix B (256 × 64) - Block 0's view:
┌─────────────┬─────────────┐
│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │
│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │
│ ┃ 128×32 ┃ │ ║ 128×32 ║ │
│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │
│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │
├─────────────┼─────────────┤
│ Block 2's │ │
│ region │ │
└─────────────┴─────────────┘
```
### How Windows Move (Conceptual - handled by block pipeline):
```cpp
// Iteration 0:
a_block_window origin: (tile_origin_m, 0) // K columns 0-31
b_block_window origin: (tile_origin_n, 0) // K columns 0-31
// Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31]
// Move windows to next K position:
move_tile_window(a_block_window, {0, 32});
move_tile_window(b_block_window, {0, 32});
// Iteration 1:
a_block_window origin: (tile_origin_m, 32) // K columns 32-63
b_block_window origin: (tile_origin_n, 32) // K columns 32-63
// Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63]
// Final result:
// C_tile = C_partial_0 + C_partial_1
```
**Key Insight**: The tile windows **slide along the K dimension** to cover the full inner product. Each block accumulates partial results across K iterations to compute its final tile of C.
---
## Step 4: Delegate to Block-Level Pipeline
```cpp
// Get the block-level pipeline from policy
constexpr auto block_gemm_pipeline =
Policy::template GetPracticeGemmBlockPipeline<Problem>();
// Calculate number of K iterations needed
int num_loops_k = integer_divide_ceil(K, KPerBlock); // ceil(64/32) = 2
// Allocate shared memory (LDS) for block-level computation
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()];
// Call block-level pipeline to compute C tile
const auto c_block_tile =
block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char);
```
### What's Happening:
1. **Retrieve block pipeline**: The policy provides the block-level GEMM implementation
2. **Calculate K iterations**: How many times to iterate over the K dimension
- In our example: `K=64, KPerBlock=32`**2 iterations**
- Each iteration processes 32 elements of the K dimension
- Results are accumulated across iterations
3. **Allocate shared memory**:
- `__shared__` declares memory shared by all threads in the block
- `GetStaticLDSSize()` returns the required size in bytes
- This memory is used for:
- Staging data from DRAM → LDS
- Cooperative loading by threads
- Fast access during computation
4. **Execute block pipeline**:
- Takes A and B tile windows as input
- Performs the GEMM computation: `C_tile = A_tile × B_tile`
- Returns result in `c_block_tile` (stored in VGPRs - registers)
### Memory Hierarchy During Computation:
```
┌─────────────────────────────────────────────────────────────┐
│ DRAM (Global Memory) - Slowest, Largest │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ A matrix │ │ B matrix │ │ C matrix │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
└─────────────────────────────────────────────────────────────┘
↓ load ↓ load ↑ store
┌─────────────────────────────────────────────────────────────┐
│ LDS (Shared Memory) - Fast, Limited Size (~64KB) │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ A_tile │ │ B_tile │ ← Staged here │
│ │ (p_smem) │ │ (p_smem) │ │
│ └─────────────┘ └─────────────┘ │
└─────────────────────────────────────────────────────────────┘
↓ load ↓ load
┌─────────────────────────────────────────────────────────────┐
│ VGPRs (Registers) - Fastest, Smallest (~256 regs/thread) │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ c_block_tile (accumulated result) │ │
│ │ Computation happens here using MFMA instructions │ │
│ └─────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
### Block Pipeline Responsibilities:
The block pipeline (called here) will:
1. Load A and B tiles from DRAM → LDS (cooperative loading)
2. Distribute work among warps
3. Each warp loads its portion from LDS → VGPRs
4. Perform MFMA operations: `C += A × B`
5. Accumulate results in VGPRs
6. Return final `c_block_tile` in registers
---
## Step 5: Store Results to DRAM
```cpp
// Create a tile window over C for writing results
auto c_window = make_tile_window(
c_dram, // Tensor view over C
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}), // Window size: 256×128
{tile_origin_m, tile_origin_n} // Origin: varies by block
);
// Store computed tile from VGPRs to DRAM
store_tile(c_window, c_block_tile);
```
### C Window Origins for Each Block:
```cpp
// Block 0: Writes to top-left tile
c_window origin: (0, 0) writes to C[0:255, 0:127]
// Block 1: Writes to bottom-left tile
c_window origin: (256, 0) writes to C[256:511, 0:127]
// Block 2: Writes to top-right tile
c_window origin: (0, 128) writes to C[0:255, 128:255]
// Block 3: Writes to bottom-right tile
c_window origin: (256, 128) writes to C[256:511, 128:255]
```
### What's Happening:
1. **Create C tile window**:
- Size: 256×128 (matches our block tile size)
- Origin: Varies by block - each block writes to its assigned region
- This window defines **where** to write the results
2. **Store tile to DRAM**:
- `c_block_tile`: Computed results in VGPRs (registers)
- `c_window`: Destination window in DRAM
- `store_tile()`: Efficiently writes data from registers → DRAM
### The `store_tile` Function:
Recall from our earlier discussion, `store_tile` does:
```cpp
template <typename TileWindow, typename DistributedTensor>
void store_tile(TileWindow& tile_window_tmp,
const DistributedTensor& dstr_tensor)
{
// 1. Extract tile distribution from distributed tensor
using TileDstr = typename DistributedTensor::TileDistribution;
// 2. Upgrade simple tile window to one with distribution
auto tile_window = make_tile_window(
tile_window_tmp.get_bottom_tensor_view(),
tile_window_tmp.get_window_lengths(),
tile_window_tmp.get_window_origin(),
TileDstr{} // Add distribution info
);
// 3. Store using vectorized writes
tile_window.store(dstr_tensor);
}
```
### Memory Flow:
```
VGPRs (Registers) DRAM (Global Memory)
┌─────────────────────┐ ┌─────────────────────┐
│ c_block_tile │ │ C matrix │
│ ┌───┬───┬───┬───┐ │ │ ┌───────────────┐ │
│ │W0 │W1 │W2 │W3 │ │ store_tile │ │ │ │
│ ├───┼───┼───┼───┤ │ ==========> │ │ c_window │ │
│ │...│...│...│...│ │ vectorized │ │ (256×128) │ │
│ └───┴───┴───┴───┘ │ │ │ │ │
│ Distributed across │ │ └───────────────┘ │
│ threads/warps │ │ Origin: (0, 0) │
└─────────────────────┘ └─────────────────────┘
Each thread writes its portion using vector stores (e.g., float4)
```
### Store Optimization:
The `store_tile` function:
- Uses **vectorized stores** (write multiple elements at once)
- Ensures **coalesced memory access** (adjacent threads write adjacent memory)
- Respects **tile distribution** (each thread knows what data it owns)
- Handles **out-of-bounds** checking (for partial tiles at boundaries)
---
## Complete Flow Visualization
Let's trace the complete flow for **Block 0** (other blocks follow the same pattern):
```
┌─────────────────────────────────────────────────────────────────┐
│ Step 1: Calculate Tile Coverage │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ M=512, N=256, K=64 │ │
│ │ MPerBlock=256, NPerBlock=128, KPerBlock=32 │ │
│ │ num_tile_m = ceil(512/256) = 2 │ │
│ │ num_tile_n = ceil(256/128) = 2 │ │
│ │ Total blocks needed = 2 × 2 = 4 blocks │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ Step 2: Map Block to Tile (Block 0 example) │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Block ID: 0 │ │
│ │ Tile coordinates: (0, 0) - top-left tile │ │
│ │ Tile origin: (0, 0) │ │
│ │ │ │
│ │ (Blocks 1,2,3 get different tile coordinates) │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ Step 3: Create Tile Windows │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ a_block_window: 256×32 starting at (0,0) over A │ │
│ │ b_block_window: 128×32 starting at (0,0) over B │ │
│ │ Windows initially cover K columns 0-31 │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ Step 4: Execute Block Pipeline (2 K iterations) │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Allocate shared memory (LDS) │ │
│ │ Call block_gemm_pipeline(a_window, b_window, 2, p_smem) │ │
│ │ │ │
│ │ K Iteration 0 (K=0-31): │ │
│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │
│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │
│ │ ├─ Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] │ │
│ │ └─ Move windows: {0, 32} │ │
│ │ │ │
│ │ K Iteration 1 (K=32-63): │ │
│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │
│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │
│ │ ├─ Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] │ │
│ │ └─ Accumulate: C_tile = C_partial_0 + C_partial_1 │ │
│ │ │ │
│ │ Return c_block_tile in VGPRs (256×128 accumulated result) │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ Step 5: Store Results │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Create c_window: 256×128 starting at (0,0) over C │ │
│ │ store_tile(c_window, c_block_tile) │ │
│ │ └─ Write from VGPRs → DRAM (vectorized stores) │ │
│ │ │ │
│ │ Block 0 writes to C[0:255, 0:127] │ │
│ │ (Other blocks write to their respective regions) │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
All 4 blocks execute in parallel, each computing its assigned 256×128 tile!
```
---
## Key Concepts Summary
### 1. **Tile Coverage**
- Determines how many thread blocks are needed
- Each block processes one tile of the output matrix C
- Calculated as `ceil(dimension / tile_size)`
### 2. **Block-to-Tile Mapping**
- Maps linear block ID to 2D tile coordinates
- Uses column-major ordering for better memory coalescing
- Each block knows which tile it's responsible for
### 3. **Tile Windows**
- **Sliding windows** over larger tensor views
- Define a rectangular region with fixed size and movable origin
- Provide efficient, structured access to tensor data
- Can be moved to access different regions (e.g., for K iterations)
### 4. **Memory Hierarchy**
- **DRAM (Global)**: Largest, slowest - stores full matrices
- **LDS (Shared)**: Medium, fast - stages tiles for cooperative access
- **VGPRs (Registers)**: Smallest, fastest - performs computation
### 5. **Data Flow**
```
DRAM → Tile Windows → LDS → VGPRs → Computation → VGPRs → DRAM
↑ ↓
A, B matrices C matrix
```
---
## Next Steps
The host-level pipeline has set up the work and delegated to the block-level pipeline. Next, we'll explore:
- **Block-level pipeline**: How tiles are loaded, distributed to warps, and computed
- **Warp-level pipeline**: How warps perform MFMA operations
- **Memory optimization**: LDS usage, bank conflicts, coalescing
The host level provides the **orchestration**, while the block and warp levels provide the **execution**!

View File

@@ -0,0 +1,464 @@
# PracticeGemmKernel: Understanding the Kernel Entry Point
This document explains the `PracticeGemmKernel` structure, which serves as the **entry point** for our GEMM GPU kernel. We'll dive deep into how raw memory is transformed into structured tensor views.
## Overview
The `PracticeGemmKernel` is a templated struct that:
1. Takes raw device memory pointers for matrices A, B, and C
2. Wraps them into **tensor views** - logical, structured views over physical memory
3. Dispatches to the host-level pipeline for computation
```cpp
template <typename Problem_, typename Policy_>
struct PracticeGemmKernel
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
static constexpr index_t kBlockSize = 256;
CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
const typename Problem::BDataType* p_b,
typename Problem::CDataType* p_c,
const index_t M,
const index_t N,
const index_t K,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c) const
{
// Step 1: Create tensor views over raw memory
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{});
auto b_dram = make_naive_tensor_view<address_space_enum::global>(
p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{});
const auto c_dram = make_naive_tensor_view<address_space_enum::global>(
p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{});
// Step 2: Dispatch to host-level pipeline
PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram);
}
};
```
---
## What are Tensor Views?
A **tensor view** is a **logical, structured view over raw physical memory**. It doesn't own or allocate memory—it simply provides a way to interpret and access existing memory as a multi-dimensional tensor.
### Key Components of a Tensor View:
1. **Memory Type**: Where the data lives (global/DRAM, LDS/shared, registers)
2. **Raw Pointer**: Points to the actual data in memory
3. **Shape**: Dimensions of the tensor (e.g., M×K for matrix A)
4. **Strides**: How to navigate through memory to access elements
5. **Guaranteed Vector Length**: How many consecutive elements can be loaded in one vector instruction
6. **Guaranteed Vector Stride**: The stride of those vectorizable elements
---
## The Memory Abstraction Hierarchy
CK Tile uses a three-layer abstraction to go from raw memory to structured tensors:
```
┌─────────────────────────────────────────────────────────────┐
│ Layer 3: TENSOR VIEW │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ • Logical multi-dimensional structure │ │
│ │ • Shape: (M, K) = (256, 32) │ │
│ │ • Strides: (32, 1) for row-major layout │ │
│ │ • Provides: operator[], coordinate-based access │ │
│ │ • Knows: How to map (i,j) → linear offset │ │
│ └─────────────────────────────────────────────────────────┘ │
│ ↓ wraps │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Layer 2: BUFFER VIEW │ │
│ │ ┌─────────────────────────────────────────────────────┐ │ │
│ │ │ • Linear view of memory │ │ │
│ │ │ • Pointer: p_data_ → device memory │ │ │
│ │ │ • Size: Total number of elements │ │ │
│ │ │ • Address space: global/LDS/generic │ │ │
│ │ │ • Provides: Vectorized loads/stores, bounds checking│ │ │
│ │ └─────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ ↓ wraps │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Layer 1: RAW PHYSICAL MEMORY │ │
│ │ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ │
│ │ │ 0.0 │ 1.0 │ 2.0 │ 3.0 │ 4.0 │ 5.0 │ 6.0 │ 7.0 │ ... │ │ │
│ │ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ │ │
│ │ ↑ │ │
│ │ p_a (raw pointer from hipMalloc) │ │
│ └─────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
---
## Deep Dive: `make_naive_tensor_view`
Let's break down the function call for matrix A:
```cpp
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
p_a, // Raw pointer to device memory
make_tuple(M, K), // Shape: (256, 32)
make_tuple(stride_a, 1), // Strides: (32, 1) - row-major
number<8>{}, // Guaranteed vector length
number<1>{} // Guaranteed vector stride
);
```
### Function Signature:
```cpp
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
memory_operation_enum DstInMemOp = memory_operation_enum::set,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
typename DataType,
typename... Lengths,
typename... Strides,
index_t GuaranteedLastDimensionVectorLength = -1,
index_t GuaranteedLastDimensionVectorStride = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view(DataType* __restrict__ p,
const tuple<Lengths...>& lengths,
const tuple<Strides...>& strides,
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
// Step 1: Create tensor descriptor (shape + stride information)
auto desc = make_naive_tensor_descriptor(lengths,
strides,
number<GuaranteedLastDimensionVectorLength>{},
number<GuaranteedLastDimensionVectorStride>{});
// Step 2: Create buffer view (pointer + size + address space)
auto buffer_view =
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
// Step 3: Combine into tensor view
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
}
```
---
## Parameter Breakdown
### 1. **Template Parameter: `address_space_enum::global`**
Specifies where the memory lives:
- `global`: GPU global memory (DRAM) - slowest but largest
- `lds`: Local Data Share (shared memory) - fast, limited size
- `generic`: Generic address space
- `vgpr`: Vector General Purpose Registers - fastest, smallest
In our case, `global` means the data is in GPU DRAM.
### 2. **`p_a` - Raw Pointer**
The raw device memory pointer returned by `hipMalloc`. Points to the start of the matrix data.
### 3. **`make_tuple(M, K)` - Shape/Lengths**
Defines the logical dimensions of the tensor:
- For matrix A: `(256, 32)` means 256 rows, 32 columns
- This is the **logical view**, independent of how data is physically laid out
### 4. **`make_tuple(stride_a, 1)` - Strides**
Defines how to navigate through memory:
- **Stride for dimension 0 (rows)**: `stride_a = K = 32`
- To move to the next row, skip 32 elements
- **Stride for dimension 1 (columns)**: `1`
- To move to the next column, skip 1 element
**Row-major layout example:**
```
Memory: [a₀₀, a₀₁, a₀₂, ..., a₀₃₁, a₁₀, a₁₁, a₁₂, ..., a₁₃₁, ...]
↑ ↑
Row 0 starts here Row 1 starts here (offset = 32)
To access element A[i][j]:
offset = i * stride_a + j * 1
= i * 32 + j
```
### 5. **`number<8>{}` - Guaranteed Last Dimension Vector Length**
This tells the tensor view: **"The last dimension (K) is guaranteed to have at least 8 consecutive elements that can be loaded together in a single vector instruction."**
#### Why is this important?
Modern GPUs can load multiple elements in one instruction (vectorized loads):
- `float4`: Load 4 floats at once
- `float8`: Load 8 floats at once (if supported)
By specifying `number<8>{}`, we're telling the system:
- "You can safely use vector loads of up to 8 elements"
- "The memory alignment and layout support this"
**Example:**
```cpp
// Without vectorization (slow):
for (int j = 0; j < 8; j++) {
data[j] = memory[offset + j]; // 8 separate loads
}
// With vectorization (fast):
float8 vec = *reinterpret_cast<float8*>(&memory[offset]); // 1 load!
```
### 6. **`number<1>{}` - Guaranteed Last Dimension Vector Stride**
This specifies the **stride between consecutive vectorizable elements** in the last dimension.
- `number<1>{}` means: "Consecutive elements in the last dimension are contiguous in memory (stride = 1)"
- This confirms that elements `A[i][0], A[i][1], A[i][2], ..., A[i][7]` are stored consecutively
**Why does this matter?**
For efficient vectorized loads, elements must be:
1. **Contiguous** (stride = 1) ✓
2. **Aligned** properly in memory
3. **Within the same cache line** (ideally)
If the stride were `2`, it would mean:
```
A[i][0] is at offset 0
A[i][1] is at offset 2 (not 1!)
A[i][2] is at offset 4
```
This would prevent efficient vectorization.
---
## What is a Buffer View?
A **buffer view** is the middle layer between raw memory and tensor view. It provides:
### Core Responsibilities:
1. **Memory Management**
- Holds the raw pointer: `T* p_data_`
- Tracks buffer size: `BufferSizeType buffer_size_`
- Knows the address space: `global`, `lds`, etc.
2. **Vectorized Access**
```cpp
template <typename VectorType>
CK_TILE_DEVICE VectorType get(index_t offset);
```
- Provides efficient vector loads/stores
- Handles alignment requirements
3. **Bounds Checking** (optional)
```cpp
template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto get(index_t i, index_t linear_offset);
```
- Can optionally check if access is within bounds
- Returns invalid value (default 0) for out-of-bounds access
4. **Address Space Awareness**
- Uses different load/store instructions based on address space
- Global memory: `global_load`, `global_store`
- LDS: `ds_read`, `ds_write`
### Buffer View Structure:
```cpp
template <address_space_enum BufferAddressSpace,
typename T,
typename BufferSizeType,
bool InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum Coherence>
struct buffer_view
{
T* p_data_; // Raw pointer
BufferSizeType buffer_size_; // Total elements
remove_cvref_t<T> invalid_element_value_; // Value for OOB access
// Access operators
const T& operator[](index_t i) const; // Read
T& operator()(index_t i); // Write
// Vectorized access
template <typename VectorType>
VectorType get(index_t offset);
};
```
---
## Visual Example: Matrix A Memory Layout
Let's visualize how matrix A (256×32, fp16) is organized:
### Raw Physical Memory (Linear):
```
GPU DRAM Address Space:
┌─────────────────────────────────────────────────────────────────┐
│ Byte 0 │
│ ↓ │
│ [a₀₀][a₀₁][a₀₂]...[a₀₃₁][a₁₀][a₁₁][a₁₂]...[a₁₃₁][a₂₀]... │
│ ↑ ↑ │
│ Row 0 (32 elements) Row 1 (32 elements) │
│ │
│ Total: 256 rows × 32 cols × 2 bytes/element = 16,384 bytes │
└─────────────────────────────────────────────────────────────────┘
p_a (raw pointer)
```
### Buffer View Layer:
```
buffer_view<address_space_enum::global, fp16_t, ...>
┌─────────────────────────────────────────────────────────────────┐
│ p_data_ = p_a │
│ buffer_size_ = 256 × 32 = 8,192 elements │
│ address_space = global (DRAM) │
│ │
│ Provides: │
│ • Linear indexing: buffer_view[i] → element at offset i │
│ • Vectorized loads: get<float4>(offset) → load 4 fp16s at once│
│ • Bounds checking: is offset < buffer_size_? │
└─────────────────────────────────────────────────────────────────┘
```
### Tensor View Layer:
```
tensor_view<buffer_view, tensor_descriptor>
┌─────────────────────────────────────────────────────────────────┐
│ Shape: (256, 32) │
│ Strides: (32, 1) │
│ Guaranteed vector length: 8 │
│ Guaranteed vector stride: 1 │
│ │
│ Logical 2D View: │
│ Col: 0 1 2 ... 31 │
│ Row 0: [a₀₀][a₀₁][a₀₂] ... [a₀₃₁] ← Can vector load 8 at once│
│ Row 1: [a₁₀][a₁₁][a₁₂] ... [a₁₃₁] │
│ Row 2: [a₂₀][a₂₁][a₂₂] ... [a₂₃₁] │
│ ... │
│ Row 255: [a₂₅₅,₀] ... [a₂₅₅,₃₁] │
│ │
│ Provides: │
│ • Multi-dimensional indexing: A[i][j] │
│ • Coordinate transformation: (i,j) → linear offset = i*32 + j │
│ • Tile window creation: Extract sub-tensors │
└─────────────────────────────────────────────────────────────────┘
```
---
## Complete Flow: Raw Memory → Tensor View
Let's trace the complete transformation for matrix A:
### Step 1: Kernel Launch (Host Side)
```cpp
// On host: Allocate device memory
hipMalloc(&p_a, M * K * sizeof(fp16_t)); // Returns raw pointer
// Launch kernel
kernel<<<grid, block>>>(p_a, p_b, p_c, M, N, K, ...);
```
### Step 2: Inside Kernel (Device Side)
```cpp
// Receive raw pointer
const fp16_t* p_a; // Points to GPU DRAM
// Step 2a: Create tensor descriptor
auto desc = make_naive_tensor_descriptor(
make_tuple(256, 32), // Shape
make_tuple(32, 1), // Strides
number<8>{}, // Vector length
number<1>{} // Vector stride
);
// desc now knows: "This is a 256×32 tensor, row-major, vectorizable by 8"
// Step 2b: Create buffer view
auto buffer_view = make_buffer_view<address_space_enum::global>(
p_a, // Raw pointer
256 * 32 // Total elements
);
// buffer_view now wraps p_a with size and address space info
// Step 2c: Create tensor view
auto a_dram = tensor_view{buffer_view, desc};
// a_dram now provides structured, multi-dimensional access to p_a
```
### Step 3: Using the Tensor View
```cpp
// Access element A[i][j]
auto value = a_dram[make_tuple(i, j)];
// Create a tile window (sub-tensor)
auto tile = make_tile_window(
a_dram,
make_tuple(16, 16), // 16×16 tile
make_tuple(0, 0) // Starting at origin
);
// Load tile into registers with vectorization
auto tile_data = load_tile(tile); // Uses vector loads internally!
```
---
## Why This Abstraction?
### Benefits:
1. **Type Safety**: Can't accidentally access wrong dimensions
2. **Performance**: Compiler knows about vectorization opportunities
3. **Flexibility**: Same code works for different memory spaces (DRAM, LDS, registers)
4. **Maintainability**: Logical structure separate from physical layout
5. **Optimization**: Guaranteed vector properties enable aggressive optimizations
### Example: Without Tensor Views (Manual Indexing)
```cpp
// Ugly, error-prone, hard to optimize:
for (int i = 0; i < 16; i++) {
for (int j = 0; j < 16; j++) {
float val = p_a[tile_offset_i * stride_a + tile_offset_j + i * stride_a + j];
// Hope the compiler vectorizes this? 🤞
}
}
```
### Example: With Tensor Views (Clean, Optimized)
```cpp
// Clean, safe, automatically vectorized:
auto tile = make_tile_window(a_dram, make_tuple(16, 16), origin);
auto tile_data = load_tile(tile); // Vectorized loads guaranteed!
```
---
## Summary
The `PracticeGemmKernel` entry point transforms raw GPU memory into structured, multi-dimensional tensors through a three-layer abstraction:
1. **Raw Memory**: Linear array of bytes in GPU DRAM
2. **Buffer View**: Adds size, address space, and vectorized access
3. **Tensor View**: Adds shape, strides, and multi-dimensional indexing
This abstraction enables:
- ✅ Clean, readable code
- ✅ Type-safe multi-dimensional access
- ✅ Automatic vectorization
- ✅ Flexible memory space handling
- ✅ Efficient tile-based computation
The tensor views created here are then passed to the host-level pipeline, which orchestrates the block-level GEMM computation!

View File

@@ -0,0 +1,150 @@
# CK Tile Practice GEMM Example
This is a practice implementation of a GEMM (General Matrix Multiplication) kernel using the CK Tile API. It demonstrates the fundamental concepts of GPU kernel development using CK Tile's hierarchical tile system.
## CK Tile API Structure
In the composable_kernel library's ck_tile API, **A Kernel is composed of a Problem, a Policy and an Epilogue**:
1. **Problem** describes the shape, data type, data layout, precision of our GEMM matrices
2. **Policy** describes how the data in the matrix (or tile) is mapped to the threads
3. **Epilogue** describes additional computation work performed after the gemm computations (this example does not have an epilogue)
## Overview
This example implements a complete GEMM kernel `C = A × B` using the CK Tile framework, showcasing:
- **Problem Setup** - Setting up the problem (input/output shapes, data types, mathematical operations), composing a kernel (pipeline, policy, epilogue), kernel launch
- **Block-level Pipelining** - creating tensor views, dispatching to block-level GEMM
- **Block-level GEMM Computation** - Block tiles, tile window creation, loading/storing to DRAM and Register memory
- **Warp-level GEMM Computation** - Warp tiles, MFMA level computation
## Problem Setup and Data Flow
### Problem Size Configuration
We set the problem size using the M, N and K variables:
```cpp
ck_tile::index_t M = 1024; // Number of rows in A and C
ck_tile::index_t N = 512; // Number of columns in B and C
ck_tile::index_t K = 256; // Number of columns in A, rows in B
```
### Host Matrix Creation
Three host matrices A (M×K), B (N×K) and C (M×N) are created, initialized on the CPU and copied over to the GPU global/DRAM memory:
```cpp
// Host tensors with proper strides
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides); // M × K
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides); // N × K
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides); // M × N
// Initialize with random data
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
// Allocate device memory and transfer data
ck_tile::DeviceMem a_device(a_host);
a_device.ToDevice(a_host.data());
```
### PracticeGemmShape Configuration
A PracticeGemmShape struct holds the dimension of each BlockTile and WaveTile:
```cpp
using BlockTile = ck_tile::sequence<256, 128, 32>; // M, N, K per block
using WaveTile = ck_tile::sequence<16, 16, 16>; // M, N, K per wave
```
- A BlockTile of size MxK (256x32) on A matrix and NxK (128x32) on B matrix. A WaveTile of size MxN (16x16) on C matrix.
- BlockTiles iterate in K dimension to fetch data required for computing region of C covered by C's block tile.
- BlockTiles are further subdivided into WarpTiles.
- WarpTiles over A and B similarly work together to calculate the WarpTile of C.
### Problem and Policy Composition
```cpp
// A Problem is composed from Shape and info about the data
using PracticeGemmHostProblem = ck_tile::
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
// A Policy is created describing data-to-thread mapping
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
// A Kernel is then composed of Problem and Policy
using gemm_kernel = ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
```
### Kernel Launch
`ck_tile::launch_kernel()` is used to launch the kernel on device. It calls the `operator()` function of `PracticeGemmKernel{}`:
```cpp
float ave_time = ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, true, 0, 0, 1},
ck_tile::make_kernel<kBlockSize, kBlockPerCU>(
gemm_kernel{}, // Kernel composed of Problem + Policy
kGridSize, // Grid dimensions
kBlockSize, // Block dimensions
0, // Dynamic shared memory
// Kernel arguments: device buffers and problem dimensions
a_device.GetDeviceBuffer(), b_device.GetDeviceBuffer(), c_device.GetDeviceBuffer(),
M, N, K, stride_a, stride_b, stride_c));
```
### Result Verification
The results from the kernel are compared with results from CPU based computation function:
```cpp
// CPU reference implementation
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(a_host, b_host, c_host_ref);
// Device results
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
// Verify correctness
bool pass = ck_tile::check_err(c_host_dev, c_host_ref);
```
### Runtime Flow
The main program (`practice_gemm.cpp`) is the entry point for the runtime flow:
```cpp
int main()
{
// 1. Define data types and problem sizes
using ADataType = ck_tile::half_t;
ck_tile::index_t M = 2048, N = 1024, K = 512;
// 2. Create host tensors and initialize
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
// 3. Allocate device memory and transfer data
ck_tile::DeviceMem a_device(a_host);
// 4. Configure tile shapes
using BlockTile = ck_tile::sequence<256, 128, 32>;
using WaveTile = ck_tile::sequence<16, 16, 16>;
// 5. Launch kernel
using gemm_kernel = ck_tile::PracticeGemmKernel<Problem, Policy>;
float ave_time = ck_tile::launch_kernel(/*...*/);
// 6. Verify results
bool pass = verify_results(a_host, b_host, c_host);
// 7. Print performance metrics
print_performance_metrics(ave_time, M, N, K);
}
```
## Building and Running
```bash
# From composable_kernel root directory
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_practice_gemm -j
# Run with sample sizes
./bin/tile_example_practice_gemm
```
This example serves as a foundation for understanding more complex GEMM implementations and optimization strategies in the CK Tile framework.

View File

@@ -0,0 +1,506 @@
# Practice GEMM: Step-by-Step Code Walkthrough
This document provides a detailed walkthrough of `practice_gemm.cpp`, explaining each step of implementing a GEMM (General Matrix Multiplication) kernel using the CK Tile API.
## Overview
We'll implement `C = A × B` where:
- `A` is an `M × K` matrix
- `B` is an `N × K` matrix (note: transposed layout)
- `C` is an `M × N` matrix
The implementation uses a hierarchical tiling strategy with two levels:
1. **Block Tiles**: Processed by thread blocks
2. **Wave Tiles**: Processed by warps (wavefronts) within blocks
---
## Step 1: Define Data Types
```cpp
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using CDataType = float;
using AccDataType = float;
```
**What's happening:**
- We use `half_t` (FP16) for input matrices A and B.
- We use `float` (FP32) for output matrix C and accumulation for numerical accuracy
- In typical CK examples, this information is part of a `GemmConfig` struct, but here we define it directly for simplicity
---
## Step 2: Define Problem Size
```cpp
ck_tile::index_t M = 512;
ck_tile::index_t N = 256;
ck_tile::index_t K = 64;
ck_tile::index_t verification = 1;
ck_tile::index_t stride_a = K;
ck_tile::index_t stride_b = K;
ck_tile::index_t stride_c = N;
```
**What's happening:**
- `M = 512`: Number of rows in A and C
- `N = 256`: Number of columns in B and C
- `K = 64`: Inner dimension (columns of A, rows of B)
- Strides define memory layout (row-major for A and C, transposed for B)
**Memory Layout:**
```
Matrix A (M×K): Matrix B (N×K): Matrix C (M×N):
[512 rows] [256 rows] [512 rows]
[64 cols] [64 cols] [256 cols]
stride = K stride = K stride = N
```
---
## Step 3: Create Host Tensors
```cpp
auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
auto a_strides = std::array<ck_tile::index_t, 2>{stride_a, 1};
auto b_strides = std::array<ck_tile::index_t, 2>{stride_b, 1};
auto c_strides = std::array<ck_tile::index_t, 2>{stride_c, 1};
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides);
```
**What's happening:**
- We create three tensors on the host (CPU) memory
- Each tensor is defined by its shape (`lengths`) and memory layout (`strides`)
- `HostTensor` is a CK Tile utility class that manages CPU memory
**Stride explanation:**
- For A: `stride_a = K` means moving to the next row requires skipping K elements
- For B: `stride_b = K` means B is stored in transposed format
- For C: `stride_c = N` means row-major layout
---
## Step 4: Initialize Tensors with Random Data
```cpp
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
c_host.SetZero();
```
**What's happening:**
- A and B are filled with random values in the range [-5.0, 5.0]
- C is initialized to zero (will store the output)
**Optional: Print Tensor Contents**
```cpp
// Commented out in the code, but available for debugging:
// a_host.print_first_n(10); // Print first 10 elements of A
```
The `print_first_n()` helper function can display tensor contents for debugging purposes.
---
## Step 5: Allocate Device Memory and Transfer Data
```cpp
ck_tile::DeviceMem a_device(a_host);
ck_tile::DeviceMem b_device(b_host);
ck_tile::DeviceMem c_device(c_host);
```
**What's happening:**
- `DeviceMem` allocates GPU memory matching the size of host tensors
- The constructor **automatically transfers data from host to device**
- This is a convenience wrapper around `hipMalloc` and `hipMemcpy`
**Memory Flow:**
```
CPU (Host) GPU (Device)
┌─────────┐ ┌─────────┐
│ a_host │ ────────> │a_device │
│ b_host │ ────────> │b_device │
│ c_host │ ────────> │c_device │
└─────────┘ └─────────┘
```
---
## Step 6: Configure Hierarchical Tiling
```cpp
using BlockTile = ck_tile::sequence<256, 128, 32>;
using WaveTile = ck_tile::sequence<16, 16, 16>;
```
**What's happening:**
- We define a two-level tiling hierarchy for the GEMM computation
### Block Tile (256 × 128 × 32)
- **256**: M dimension per block (rows of A and C)
- **128**: N dimension per block (columns of B and C)
- **32**: K dimension per block (inner dimension)
- Each block tile is processed by one **thread block** (256 threads)
### Wave Tile (16 × 16 × 16)
- **16 × 16**: Output tile dimensions (M × N) per warp iteration
- **16**: K dimension per warp iteration
- Each wave tile is processed by one **warp** (64 threads on AMD GPUs)
**Important:** The WaveTile (16×16×16) is NOT the same as the MFMA instruction size (32×32×8). The WaveTile represents the work done per warp per iteration, while MFMA is the underlying hardware instruction. Multiple MFMA operations may be needed to compute one wave tile
**Important Note:**
In this example, the problem size (256 × 128 × 32) is **identical** to the block tile size, so only **one thread block** is needed to compute the entire problem.
### Tiling Visualization:
#### Matrix A (M × K = 256 × 32):
```
┌─────────────────────────────────────┐
│ One Block Tile (256 × 32) │
│ ┌────┬────┐ │
│ │16×│16× │ ← Wave tiles (16×16) │
│ │ 16│ 16 │ in M×K space │
│ ├────┼────┤ │
│ │ │ │ │
│ ├────┼────┤ │
│ │ .. │ .. │ 16 tiles in M │
│ ├────┼────┤ 2 tiles in K │
│ │ │ │ │
│ └────┴────┘ │
│ │
└─────────────────────────────────────┘
```
#### Matrix B (N × K = 128 × 32):
```
┌──────────────────────────────┐
│ One Block Tile (128 × 32) │
│ ┌────┬────┐ │
│ │16×│16× │ ← Wave tiles │
│ │ 16│ 16 │ (16×16) │
│ ├────┼────┤ │
│ │ │ │ │
│ ├────┼────┤ 8 tiles in N │
│ │ .. │ .. │ 2 tiles in K │
│ ├────┼────┤ │
│ │ │ │ │
│ └────┴────┘ │
└──────────────────────────────┘
```
#### Matrix C (M × N = 256 × 128) - Output:
```
┌─────────────────────────────────────────────────┐
│ One Block Tile (256 × 128) │
│ │
│ ┌────┬────┬────┬────┬────┬────┬────┬────┐ │
│ │16× │ │ │ │ │ │ │ │ │
│ │ 16 │ │ │ │ │ │ │ │ │
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
│ │ │ │ │ │ │ │ │ │ │
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
│ │ │ │ │ │ │ │ │ │ │
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
│ │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ │
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
│ │ │ │ │ │ │ │ │ │ │
│ └────┴────┴────┴────┴────┴────┴────┴────┘ │
│ │
│ 16 wave tiles in M direction │
│ 8 wave tiles in N direction │
│ Total: 128 wave tiles (16×16 each) │
└─────────────────────────────────────────────────┘
```
#### How Wave Tiles Combine (C = A × B):
```
Matrix A Matrix B (stored transposed N×K) Matrix C
(256×32) (128×32) (256×128)
Row of A tiles: Row of B tiles: One wave tile in C:
┌────┬────┐ ┌────┬────┐ ┌────┐
│ A₀ │ A₁ │ × │ B₀ │ B₁ │ = │ C │ (16×16)
└────┴────┘ └────┴────┘ └────┘
16×16 each 16×16 each
Computation: C = A₀×B₀ᵀ + A₁×B₁ᵀ
↑ ↑
K=0..15 K=16..31
Each wave tile in C is computed by:
- Taking one row of wave tiles from A (2 tiles along K)
- Taking one row of wave tiles from B (2 tiles along K)
Note: B is stored transposed (N×K), so a "row" in storage corresponds
to a "column" in the logical B^T matrix used in computation
- Performing dot product: Σ(A_k × B_k^T) for k=0,1
```
**Key Insight:**
- Each **wave tile in C** (16×16) requires a **dot product** of 2 wave tiles from A and 2 wave tiles from B
- Since B is stored transposed (N×K layout), we access **rows** of B tiles in memory
- This is the fundamental operation repeated across all 128 wave tiles in C
- Each warp computes one wave tile using MFMA instructions
---
## Step 7: Create Shape, Problem, and Policy Structs
```cpp
using PracticeGemmShape = ck_tile::PracticeGemmShape<BlockTile, WaveTile>;
std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl;
using PracticeGemmHostProblem = ck_tile::
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
```
**What's happening:**
### 1. **Shape Struct**
Encapsulates all tile shape information (BlockTile and WaveTile dimensions).
### 2. **Problem Struct**
Holds complete problem description:
- Data types (ADataType, BDataType, CDataType, AccDataType)
- Shape information (BlockTile, WaveTile)
In more complex examples, this would also include:
- Data layouts (row-major, column-major)
- Mathematical operations (e.g., transposed GEMM)
### 3. **Policy Struct**
Describes data movement and thread-to-data mapping:
- Currently contains `MakeBlock2TileMap()`: Maps thread block IDs to tile positions
- In more complex kernels, includes:
- DRAM access patterns
- LDS (Local Data Share) usage strategies
- Thread distribution within blocks
**CK Tile Design Pattern:**
```
Kernel = Problem + Policy + Epilogue
↑ ↑ ↑
(What) (How) (Post-processing)
```
---
## Step 8: Calculate Grid and Block Dimensions
```cpp
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
std::cout << "kGridSize: " << kGridSize << std::endl;
constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kBlockPerCU = 1;
```
**What's happening:**
### Grid Size Calculation
```cpp
kGridSize = ceil(M / BlockTile_M) × ceil(N / BlockTile_N)
= ceil(512 / 256) × ceil(256 / 128)
= 2 × 2
= 4 thread blocks
```
Our problem requires **4 thread blocks** to cover the entire output matrix C (2 blocks in M direction, 2 blocks in N direction).
### Block Configuration
- `kBlockSize = 256`: Each thread block has 256 threads
- 256 threads / 64 threads per warp = **4 warps per block**
- `kBlockPerCU = 1`: Launch 1 block per Compute Unit (for simplicity)
**Thread Hierarchy:**
```
GPU
└── 1 Thread Block (Grid)
└── 256 Threads
├── Warp 0 (threads 0-63)
├── Warp 1 (threads 64-127)
├── Warp 2 (threads 128-191)
└── Warp 3 (threads 192-255)
```
---
## Step 9: Create and Launch the Kernel
```cpp
using gemm_kernel =
ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
float ave_time = ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, true, 0, 0, 1},
ck_tile::make_kernel<kBlockPerCU>(gemm_kernel{},
kGridSize,
kBlockSize,
0,
static_cast<ADataType*>(a_device.GetDeviceBuffer()),
static_cast<BDataType*>(b_device.GetDeviceBuffer()),
static_cast<CDataType*>(c_device.GetDeviceBuffer()),
M,
N,
K,
stride_a,
stride_b,
stride_c));
```
**What's happening:**
### 1. Kernel Composition
```cpp
using gemm_kernel = ck_tile::PracticeGemmKernel<Problem, Policy>;
```
The kernel is composed from Problem and Policy structs, following the CK Tile design pattern.
### 2. Kernel Launch
`launch_kernel()` is a CK Tile utility that:
- Launches the GPU kernel using HIP runtime
- Measures execution time
- Returns average execution time in milliseconds
### 3. Launch Parameters
- **Stream config**: `{nullptr, true, 0, 0, 1}` - default stream, timing enabled
- **Grid size**: `kGridSize = 1` - number of thread blocks
- **Block size**: `kBlockSize = 256` - threads per block
- **Shared memory**: `0` - no dynamic shared memory in this example
- **Kernel arguments**: Device pointers and problem dimensions
### 4. Kernel Execution Flow
```
launch_kernel() calls gemm_kernel.operator()()
PracticeGemmKernel::operator()
Creates tensor views over device memory
Calls block-level pipeline
Block pipeline calls warp-level pipeline
Warp pipeline calls MFMA instructions
Results written back to C matrix
```
---
## Step 10: Verify Results
```cpp
auto pass = true;
if(verification)
{
// Reference gemm on CPU
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_host, b_host, c_host_ref);
// Copy GPU results back to host
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
c_device.FromDevice(c_host_dev.mData.data());
// Compare results
pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
```
**What's happening:**
### 1. CPU Reference Implementation
```cpp
reference_basic_gemm<...>(a_host, b_host, c_host_ref);
```
Computes GEMM on CPU using a simple nested loop implementation (ground truth).
### 2. Copy GPU Results to Host
```cpp
c_device.FromDevice(c_host_dev.mData.data());
```
Transfers the computed result from GPU memory back to CPU for comparison.
### 3. Error Checking
```cpp
ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
```
Compares GPU and CPU results element-wise with tolerance:
- **Relative error**: 1e-3 (0.1%)
- **Absolute error**: 1e-3
**Verification Flow:**
```
CPU GPU
┌─────────┐ ┌─────────┐
│ a_host │ ────────> │a_device │
│ b_host │ ────────> │b_device │
└─────────┘ └─────────┘
│ │
↓ ↓
reference_gemm() GPU kernel
│ │
↓ ↓
┌──────────┐ ┌──────────┐
│c_host_ref│ │c_device │
└──────────┘ └──────────┘
│ │
│ ↓
│ FromDevice()
│ │
↓ ↓
└────> check_err() <───┘
Pass/Fail
```
---
## Complete Execution Flow Summary
```
1. Define data types (FP16 inputs, FP32 output)
2. Set problem size (M=256, N=128, K=32)
3. Create host tensors and initialize with random data
4. Allocate device memory and transfer data (CPU → GPU)
5. Configure hierarchical tiling (BlockTile, WaveTile)
6. Create Shape, Problem, and Policy structs
7. Calculate grid/block dimensions (1 block, 256 threads)
8. Compose and launch kernel (Problem + Policy)
9. Execute GEMM on GPU
│ ├─ Block-level pipeline
│ ├─ Warp-level pipeline
│ └─ MFMA instructions
10. Verify results (compare GPU vs CPU reference)
11. Calculate and print performance metrics
12. Return success/failure
```
---

View File

@@ -0,0 +1,165 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = PracticeGemmBlockPolicy>
struct PracticeGemmBlockPipelineAGmemBGmemCreg
{
using ADataType = typename Problem::ADataType;
using BDataType = typename Problem::BDataType;
using CDataType = typename Problem::CDataType;
using AccDataType = typename Problem::AccDataType;
using BlockTile = typename Problem::Shape::BlockTile;
using WaveTile = typename Problem::Shape::WaveTile;
static constexpr index_t MPerBlock = BlockTile::at(number<0>{});
static constexpr index_t NPerBlock = BlockTile::at(number<1>{});
static constexpr index_t KPerBlock = BlockTile::at(number<2>{});
static constexpr index_t MPerWave = WaveTile::at(number<0>{});
static constexpr index_t NPerWave = WaveTile::at(number<1>{});
static constexpr index_t KPerWave = WaveTile::at(number<2>{});
using BlockGemm =
remove_cvref_t<decltype(Policy::template GetPracticeWaveGemmPipeline<Problem>())>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLDSSize()
{
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// -----------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
// Block GEMM
auto block_gemm = BlockGemm();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile;
BBlockTile b_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock);
// -------------------------------------------------------------------------------------
// Gemm pipeline start
// Initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// non-prefetch
index_t iCounter = num_loop;
while(iCounter > 0)
{
a_block_tile = load_tile(a_copy_dram_window); // from DRAM to registers
b_block_tile = load_tile(b_copy_dram_window); // from DRAM to registers
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
store_tile(a_copy_lds_window, a_block_tile); // from registers to LDS
store_tile(b_copy_lds_window, b_block_tile); // from registers to LDS
block_sync_lds();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // from LDS to registers
block_sync_lds();
iCounter--;
}
return c_block_tile;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,135 @@
#pragma once
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "../warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp"
#include "../warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp"
namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename AccDataType_,
typename Shape_>
struct PracticeGemmBlockPipelineProblem
{
using ADataType = ADataType_;
using BDataType = BDataType_;
using CDataType = CDataType_;
using AccDataType = AccDataType_;
using Shape = Shape_;
};
struct PracticeGemmBlockPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPracticeWaveGemmPipeline()
{
return PracticeGemmWarpPipelineASmemBSmemCreg<Problem>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{});
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
constexpr index_t kKPack = 8;
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{});
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
constexpr index_t kKPack = 8;
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BlockGemm = remove_cvref_t<decltype(GetPracticeWaveGemmPipeline<Problem>())>;
constexpr index_t kMWarp = BlockGemm::MWarp;
constexpr index_t kNWarp = BlockGemm::NWarp;
constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size();
constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{});
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockGemm = remove_cvref_t<decltype(GetPracticeWaveGemmPipeline<Problem>())>;
constexpr index_t kMWarp = BlockGemm::MWarp;
constexpr index_t kNWarp = BlockGemm::NWarp;
constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size();
constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{});
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = PracticeGemmHostPolicy>
struct PracticeGemmHostPipeline
{
using ADataType = typename Problem_::ADataType;
using BDataType = typename Problem_::BDataType;
using CDataType = typename Problem_::CDataType;
using AccDataType = typename Problem_::AccDataType;
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using BlockTile = typename Problem::Shape::BlockTile;
using WaveTile = typename Problem::Shape::WaveTile;
template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
const BDRAMTensorView& b_dram,
CDRAMTensorView& c_dram_ref) const
{
// Size of the entire problem
const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K
const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N
const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K
// Size of the block tile
const auto MPerBlock = BlockTile::at(number<0>{});
const auto NPerBlock = BlockTile::at(number<1>{});
const auto KPerBlock = BlockTile::at(number<2>{});
// Number of block tile in the N direction to cover C (resultant) matrix
const auto num_tile_n = integer_divide_ceil(N, NPerBlock);
// Number of block tile in the M direction to cover C (resultant) matrix
const auto num_tile_m = integer_divide_ceil(M, MPerBlock);
// if(get_thread_id() == 0 && get_block_id() == 0)
// {
// printf("num_tile_m: %d, num_tile_n: %d\n", num_tile_m, num_tile_n);
// printf("total number of tiles: %d\n", num_tile_m * num_tile_n);
// }
// Get block id
const auto id_block =
get_block_id(); // 0 to (M_block/BlockTile_M) * (N_block/BlockTile_N) - 1
// Map block id to tile id
const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n);
const auto tile_id = block2tile(id_block);
const auto tile_id_m = tile_id.at(number<0>{});
const auto tile_id_n = tile_id.at(number<1>{});
// if(get_thread_id() == 0 && get_block_id() == 15)
// {
// printf("tile_id_m: %d, tile_id_n: %d\n", tile_id_m, tile_id_n);
// }
const auto tile_origin_m = tile_id_m * MPerBlock;
const auto tile_origin_n = tile_id_n * NPerBlock;
// create a tile window over dram for A and B
const auto a_block_window = make_tile_window(
a_dram, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {tile_origin_m, 0});
const auto b_block_window = make_tile_window(
b_dram, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {tile_origin_n, 0});
constexpr auto block_gemm_pipeline =
Policy::template GetPracticeGemmBlockPipeline<Problem>();
int num_loops_k = integer_divide_ceil(K, KPerBlock);
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()];
const auto c_block_tile =
block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char);
auto c_window = make_tile_window(c_dram,
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
{tile_origin_m, tile_origin_n});
store_tile(c_window, c_block_tile);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,51 @@
#pragma once
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "../block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp"
#include "../block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp"
namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename AccDataType_,
typename Shape_>
struct PracticeGemmHostProblem
{
using ADataType = ADataType_;
using BDataType = BDataType_;
using CDataType = CDataType_;
using AccDataType = AccDataType_;
using Shape = remove_cvref_t<Shape_>;
};
struct PracticeGemmHostPolicy
{
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
{
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
return [unmerge](index_t block_id) {
multi_index<2> unmerged;
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPracticeGemmBlockPipeline()
{
using PracticeGemmBlockPipelineProblem_ =
PracticeGemmBlockPipelineProblem<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
typename Problem::AccDataType,
typename Problem::Shape>;
return PracticeGemmBlockPipelineAGmemBGmemCreg<PracticeGemmBlockPipelineProblem_>{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,131 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck_tile/host.hpp"
#include "practice_gemm.hpp"
#include "reference_gemm.hpp"
int main()
{
// TODO: GemmTypeConfig
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using CDataType = float;
using AccDataType = float;
// ArgParser
ck_tile::index_t M = 512;
ck_tile::index_t N = 256;
ck_tile::index_t K = 64;
ck_tile::index_t verification = 1;
ck_tile::index_t stride_a = K;
ck_tile::index_t stride_b = K;
ck_tile::index_t stride_c = N;
auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
auto a_strides = std::array<ck_tile::index_t, 2>{stride_a, 1};
auto b_strides = std::array<ck_tile::index_t, 2>{stride_b, 1};
auto c_strides = std::array<ck_tile::index_t, 2>{stride_c, 1};
// tensors on host (cpu)
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides);
// initialize tensors
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
c_host.SetZero();
// Print the tensors using the new print_first_n member function
// std::cout << "Tensor A (first 10 elements): ";
// a_host.print_first_n(10);
// std::cout << std::endl;
// std::cout << "Tensor B (first 10 elements): ";
// b_host.print_first_n(10);
// std::cout << std::endl;
// std::cout << "Tensor C (first 10 elements): ";
// c_host.print_first_n(10);
// std::cout << std::endl;
// Create device tensors of same size as host tensors and copy data
ck_tile::DeviceMem a_device(a_host);
ck_tile::DeviceMem b_device(b_host);
ck_tile::DeviceMem c_device(c_host);
// TODO: BlockTileConfig
// constexpr ck_tile::index_t warpSize = 64;
constexpr ck_tile::index_t kBlockSize = 256;
using BlockTile = ck_tile::sequence<256, 128, 32>;
using WaveTile = ck_tile::sequence<16, 16, 16>;
std::cout << "Creating PracticeGemmShape, PracticeGemmProblem, PracticeGemmPolicy" << std::endl;
using PracticeGemmShape = ck_tile::PracticeGemmShape<BlockTile, WaveTile>;
std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl;
using PracticeGemmHostProblem = ck_tile::
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
std::cout << "kGridSize: " << kGridSize << std::endl;
constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU
std::cout << "kBlockSize: " << kBlockSize << std::endl;
std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl;
using gemm_kernel =
ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
float ave_time = ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, true, 0, 0, 1},
ck_tile::make_kernel<kBlockPerCU>(gemm_kernel{},
kGridSize,
kBlockSize,
0,
static_cast<ADataType*>(a_device.GetDeviceBuffer()),
static_cast<BDataType*>(b_device.GetDeviceBuffer()),
static_cast<CDataType*>(c_device.GetDeviceBuffer()),
M,
N,
K,
stride_a,
stride_b,
stride_c));
auto pass = true;
if(verification)
{
// reference gemm
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_host, b_host, c_host_ref);
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
c_device.FromDevice(c_host_dev.mData.data());
pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
return !pass;
}

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp"
#include "host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp"
namespace ck_tile {
template <typename BlockTile_, typename WaveTile_>
struct PracticeGemmShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using WaveTile = remove_cvref_t<WaveTile_>;
static constexpr index_t BlockTile_M = BlockTile::at(number<0>{});
static constexpr index_t BlockTile_N = BlockTile::at(number<1>{});
static constexpr index_t BlockTile_K = BlockTile::at(number<2>{});
static constexpr index_t WaveTile_M = WaveTile::at(number<0>{});
static constexpr index_t WaveTile_N = WaveTile::at(number<1>{});
static constexpr index_t WaveTile_K = WaveTile::at(number<2>{});
CK_TILE_HOST static std::string GetName()
{
// clang-format off
return concat('_', "practice_gemm_shape",
concat('x', BlockTile_M, BlockTile_N, BlockTile_K),
concat('x', WaveTile_M, WaveTile_N, WaveTile_K));
// clang-format on
}
};
template <typename Problem_, typename Policy_>
struct PracticeGemmKernel
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
static constexpr index_t kBlockSize = 256;
CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
const typename Problem::BDataType* p_b,
typename Problem::CDataType* p_c,
const index_t M,
const index_t N,
const index_t K,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c) const
{
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{});
auto b_dram = make_naive_tensor_view<address_space_enum::global>(
p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{});
const auto c_dram = make_naive_tensor_view<address_space_enum::global>(
p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{});
PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_basic_gemm(const ck_tile::HostTensor<ADataType>& a_m_k,
const ck_tile::HostTensor<BDataType>& b_n_k,
ck_tile::HostTensor<CDataType>& c_m_n)
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int K = b_n_k.mDesc.get_lengths()[1];
auto f = [&](auto m) {
for(int n = 0; n < N; ++n)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a = a_m_k(m, k);
BDataType v_b = b_n_k(n, k);
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_acc);
}
};
ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(1);
}

View File

@@ -0,0 +1,195 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = PracticeGemmWarpPolicy>
struct PracticeGemmWarpPipelineASmemBSmemCreg
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using WaveGemmShape = remove_cvref_t<typename Problem::Shape>;
using WarpGemm = remove_cvref_t<
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
static constexpr index_t MWarp =
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
static constexpr index_t NWarp =
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// C += A * B
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
static_assert(MPerBlock == WaveGemmShape::BlockTile_M &&
NPerBlock == WaveGemmShape::BlockTile_N &&
KPerBlock == WaveGemmShape::BlockTile_K,
"wrong!");
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
#if !defined(ENABLE_PREFETCH)
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
// Construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
a_block_window_tmp.get_window_origin().at(number<1>{})},
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// Construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
b_block_window_tmp.get_window_origin().at(number<1>{})},
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// Read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// Read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// Read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// Warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// Write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// C = A * B
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
static_assert(MPerBlock == WaveGemmShape::BlockTile_M &&
NPerBlock == WaveGemmShape::BlockTile_N &&
KPerBlock == WaveGemmShape::BlockTile_K,
"wrong!");
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
// Construct C-Block-Tensor
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
// Default policy for BlockGemmASmemBSmemCReg
// Default policy class should not be templated, put template on member functions instead
struct PracticeGemmWarpPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
constexpr index_t kMWarp = 4;
constexpr index_t kNWarp = 1;
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
}
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
}
};
} // namespace ck_tile