mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add CK Tile Tutorials Folder with GEMM and COPY Kernel (#3038)
* feat: add tutorial folder with gemm tutorial
* chore: move copy kernel from examples folder to tutorial
* Update tutorial/ck_tile/01_naive_gemm/README.md
* Update tutorial/ck_tile/01_naive_gemm/README.md
* chore: remove handdrawn images
* docs: add write ups to explain the gemm kernel
* docs: add about block level pipeline and static distributed tensors

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
589
tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md
Normal file
@@ -0,0 +1,589 @@
# Block-Level Pipeline: PracticeGemmBlockPipelineAGmemBGmemCreg

## Overview

The **Block-Level Pipeline** is where the actual GEMM computation happens for one block tile. It orchestrates:

1. **Data movement** from DRAM → Registers → LDS
2. **GEMM computation** using data in LDS
3. **Iteration** over the K dimension when needed

This pipeline is called by the host-level pipeline for each block tile that covers a portion of the output matrix C.

---

## Architecture: Problem and Policy

Like other components in CK Tile, the block pipeline follows the **Problem/Policy** pattern:

### Problem: `PracticeGemmBlockPipelineProblem`
Contains:
- **Data types**: `ADataType`, `BDataType`, `CDataType`, `AccDataType`
- **Shape information**: `BlockTile` and `WaveTile` dimensions

### Policy: `PracticeGemmBlockPolicy`
Contains strategies for:
1. **Tile Distribution** (`MakeADramTileDistribution`, `MakeBDramTileDistribution`)
   - Defines how the 256 threads in a block map to elements of a block tile
   - Each thread knows which elements to load/store from DRAM to its registers
   - We'll cover tile distribution construction in detail later

2. **LDS Layout** (`MakeALdsBlockDescriptor`, `MakeBLdsBlockDescriptor`)
   - Describes how data is logically organized in Local Data Share (LDS)
   - Optimizes for bank conflict avoidance and efficient access patterns
   - We'll cover LDS descriptor construction in detail later

3. **Warp Pipeline** (`GetPracticeWaveGemmPipeline`)
   - Returns the warp-level GEMM implementation

---
## Inputs and Outputs

```cpp
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               index_t num_loop,
                               void* p_smem) const
```

### Inputs:
- `a_dram_block_window_tmp`: Tile window over A in DRAM (size: MPerBlock × KPerBlock)
- `b_dram_block_window_tmp`: Tile window over B in DRAM (size: NPerBlock × KPerBlock)
- `num_loop`: Number of iterations along the K dimension
- `p_smem`: Pointer to shared memory (LDS)

### Output:
- `c_block_tile`: A `static_distributed_tensor` containing the computed C tile in registers (VGPRs)

---
## Step-by-Step Walkthrough

### Step 1: Create LDS Tensor Views

```cpp
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

// B tile in LDS (placed after A in shared memory)
BDataType* p_b_lds = static_cast<BDataType*>(
    static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
```

**What's happening:**
- We partition the shared memory (`p_smem`) into two regions: one for A, one for B
- We create **tensor views** over these LDS regions using descriptors from the policy
- `a_lds_block` and `b_lds_block` are logical views over raw LDS memory

**Memory Layout:**
```
Shared Memory (LDS):
┌─────────────────────┬─────────────────────┐
│    A Block Tile     │    B Block Tile     │
│   (256×32 fp16)     │   (128×32 fp16)     │
└─────────────────────┴─────────────────────┘
↑                     ↑
p_a_lds               p_b_lds
```

---
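The pointer arithmetic above can be sanity-checked on the host. Here is a minimal sketch, assuming illustrative values (a 256×32 fp16 A tile, a 128×32 fp16 B tile, and 16-byte alignment; in the real kernel, `a_lds_block_space_size_aligned` comes from the policy and `align_up` is a hypothetical stand-in):

```cpp
#include <cassert>
#include <cstddef>

// Round n up to a multiple of a (stand-in for the aligned A-tile size
// that the policy computes; name and alignment are illustrative).
constexpr std::size_t align_up(std::size_t n, std::size_t a) { return (n + a - 1) / a * a; }

constexpr std::size_t a_tile_bytes = 256 * 32 * 2; // fp16 A tile: 256x32 elements, 2 bytes each
constexpr std::size_t b_tile_bytes = 128 * 32 * 2; // fp16 B tile: 128x32 elements, 2 bytes each
constexpr std::size_t b_offset     = align_up(a_tile_bytes, 16); // p_b_lds = p_smem + b_offset
constexpr std::size_t lds_bytes    = b_offset + b_tile_bytes;    // total LDS footprint
```

With these sizes, the A region already ends on a 16-byte boundary, so B starts immediately after it.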
### Step 2: Create Tile Windows for Data Movement

We create **6 tile windows** for different purposes:

#### 2a. DRAM → Registers (Load from DRAM)

```cpp
auto a_copy_dram_window = make_tile_window(
    a_dram_block_window_tmp.get_bottom_tensor_view(),
    make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32
    a_dram_block_window_tmp.get_window_origin(),
    Policy::template MakeADramTileDistribution<Problem>()); // ← Tile distribution!
```

**Key Points:**
- `a_copy_dram_window` is a `tile_window_with_static_distribution`
- The **tile distribution** tells each thread which elements to load from DRAM
- This window will **slide along the K dimension** in the loop

#### 2b. Registers → LDS (Store to LDS)

```cpp
auto a_copy_lds_window = make_tile_window(
    a_lds_block,
    make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32
    {0, 0},                                               // Origin at (0, 0) in LDS
    a_copy_dram_window.get_tile_distribution());          // ← Same distribution as DRAM!
```

**Key Points:**
- Uses the **same tile distribution** as `a_copy_dram_window`
- This ensures each thread stores to LDS in the same pattern it loaded from DRAM
- Origin is always `{0, 0}` because LDS is reused for each K iteration

#### 2c. LDS → Registers (GEMM Input)

```cpp
auto a_lds_gemm_window = make_tile_window(
    a_lds_block,
    make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
    {0, 0}); // No tile distribution!
```

**Key Points:**
- This is a `tile_window_with_static_lengths` (no explicit distribution)
- Used as input to the warp-level GEMM
- The warp GEMM will handle its own thread mapping internally

**Similar windows are created for B:**
- `b_copy_dram_window`: Load B from DRAM
- `b_copy_lds_window`: Store B to LDS
- `b_lds_gemm_window`: Read B from LDS for GEMM

---
### Step 3: Create Distributed Tensors (VGPRs)

```cpp
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

ABlockTile a_block_tile; // Per-thread registers for A
BBlockTile b_block_tile; // Per-thread registers for B
```

#### What is `make_static_distributed_tensor`?

**`make_static_distributed_tensor`** creates a **`static_distributed_tensor`**, which is a compile-time abstraction for **distributed per-thread register storage**.

**Key Properties:**
1. **Per-thread VGPRs**: Each thread owns a **different slice** of the tile in its registers
2. **Compile-time sized**: The buffer size is determined by the tile distribution at compile time
3. **Zero overhead**: All indexing and layout transformations happen at compile time

**How it works:**

```cpp
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor
{
    using DataType               = remove_cvref_t<DataType_>;
    using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;

    // Calculate per-thread storage size from tile distribution
    using ThreadTensorDesc =
        remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;

    static constexpr index_t kThreadElementSpaceSize =
        ThreadTensorDesc{}.get_element_space_size();

    // Per-thread register array (VGPRs)
    thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
};
```

**The tile distribution defines:**
- **Which elements each thread owns** in the tile
- **How many elements** each thread stores (buffer size)
- **How elements are laid out** in each thread's registers

**Concrete example for a 256×32 tile with 256 threads:**

```
Thread 0:   a_block_tile.thread_buf_ = [A[0,0], A[0,1], ..., A[0,31]]       (32 fp16 values)
Thread 1:   a_block_tile.thread_buf_ = [A[1,0], A[1,1], ..., A[1,31]]       (32 fp16 values)
Thread 2:   a_block_tile.thread_buf_ = [A[2,0], A[2,1], ..., A[2,31]]       (32 fp16 values)
...
Thread 255: a_block_tile.thread_buf_ = [A[255,0], A[255,1], ..., A[255,31]] (32 fp16 values)
```

**Collectively:**
- All 256 threads together hold the **entire 256×32 tile** (8192 elements)
- Each thread's buffer lives in its **own VGPRs**
- No two threads own the same element

**Distributed Ownership Analogy:**
Think of a tile as a **jigsaw puzzle**:
- The **tile distribution** is the cutting pattern
- Each **thread** gets one puzzle piece (its slice)
- The **`static_distributed_tensor`** type describes the full set of pieces
- Each thread's **`thread_buf_`** is its individual piece, held in its own registers

---
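The exclusive-ownership property above ("no two threads own the same element") can be checked with plain host code. A sketch, assuming the one-row-per-thread distribution described in this tutorial (thread `tid` owns row `tid`, register slot `d` holds column `d`):

```cpp
#include <cassert>

// Count how many (tid, d) pairs claim element (row, col) of the 256x32 tile
// under the assumed one-row-per-thread distribution. Should always be 1.
inline int owner_count(int row, int col) {
    int owners = 0;
    for (int tid = 0; tid < 256; ++tid)      // 256 threads in the block
        for (int d = 0; d < 32; ++d)         // 32 register slots per thread
            if (tid == row && d == col)
                ++owners;
    return owners;
}
```

Every element is claimed by exactly one thread, so the 256 thread buffers collectively cover the whole tile with no overlap.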
### Step 4: The GEMM Loop

```cpp
// Initialize C accumulator to zero
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

index_t iCounter = num_loop; // Number of K iterations

while(iCounter > 0)
{
    // 1. Load from DRAM to registers
    a_block_tile = load_tile(a_copy_dram_window); // DRAM → VGPRs
    b_block_tile = load_tile(b_copy_dram_window); // DRAM → VGPRs

    // 2. Move windows for next iteration
    move_tile_window(a_copy_dram_window, a_dram_tile_window_step); // Step by (0, 32)
    move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // Step by (0, 32)

    // 3. Store from registers to LDS
    store_tile(a_copy_lds_window, a_block_tile); // VGPRs → LDS
    store_tile(b_copy_lds_window, b_block_tile); // VGPRs → LDS

    // 4. Synchronize threads (ensure all data is in LDS)
    block_sync_lds();

    // 5. Compute GEMM using data in LDS
    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);

    // 6. Synchronize threads (before overwriting LDS in next iteration)
    block_sync_lds();

    iCounter--;
}

return c_block_tile; // Return accumulated result in registers
```

---
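The accumulation structure of this loop can be mirrored in scalar host code: summing per-K-chunk partial products gives the same C as a single full pass, which is why `c_block_tile` can stay in registers across iterations. A sketch (names and sizes are illustrative, not the library's API):

```cpp
#include <cassert>
#include <vector>

// C = A * B^T computed in K chunks of size KPerBlock: each chunk plays the
// role of one DRAM -> LDS -> MFMA iteration of the block-level loop.
std::vector<float> gemm_k_chunked(const std::vector<float>& A, // M x K, row-major
                                  const std::vector<float>& B, // N x K, row-major
                                  int M, int N, int K, int KPerBlock) {
    std::vector<float> C(M * N, 0.0f);            // accumulator, like c_block_tile
    for (int k0 = 0; k0 < K; k0 += KPerBlock)     // num_loop = K / KPerBlock iterations
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n)
                for (int k = k0; k < k0 + KPerBlock; ++k)
                    C[m * N + n] += A[m * K + k] * B[n * K + k];
    return C;
}
```

Because each chunk touches a disjoint K range and additions happen in the same order, the chunked result matches an unchunked pass exactly.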
## Detailed Loop Breakdown

### Phase 1: Load (DRAM → VGPRs)

```cpp
a_block_tile = load_tile(a_copy_dram_window);
```

**What happens:**
1. Each thread reads **its assigned elements** from DRAM (determined by the tile distribution)
2. Data is loaded into **per-thread registers** (VGPRs)
3. Uses **vectorized loads** for efficiency (e.g., loading 8 fp16 values at once)

**Example for Thread 0 (which owns row 0):**
```
Thread 0 loads:
  A[0,0:7]   (8 fp16 values, one vector load)
  A[0,8:15]  (8 fp16 values, one vector load)
  ...
```

### Phase 2: Move Windows

```cpp
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
```

**What happens:**
- The tile window **slides along the K dimension** by `KPerBlock` (32 in our example)
- This prepares for the next K iteration
- The window origin moves from `(0, 0)` → `(0, 32)` → `(0, 64)` → ...

**Visualization for Problem Size 512×256×64:**
```
Matrix A (512×64):
┌─────────────────────────────────────┐
│ Block 0: rows 0-255                 │
│   ┌──────────┬──────────┐           │
│   │  K=0:31  │  K=32:63 │           │ ← Window slides right
│   │  Iter 0  │  Iter 1  │           │
│   └──────────┴──────────┘           │
└─────────────────────────────────────┘
```
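The origin progression described above is easy to enumerate on the host. A sketch, using this tutorial's values (step of `(0, KPerBlock)`; the function name is illustrative):

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Origins visited by the A window as it slides along K in steps of KPerBlock.
std::vector<std::pair<int, int>> window_origins(int K, int KPerBlock) {
    std::vector<std::pair<int, int>> origins;
    for (int k = 0; k < K; k += KPerBlock)
        origins.push_back({0, k}); // row origin fixed, K origin advances
    return origins;
}
```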
### Phase 3: Store (VGPRs → LDS)

```cpp
store_tile(a_copy_lds_window, a_block_tile);
```

**What happens:**
1. Each thread writes **its elements** from registers to LDS
2. Uses the **same distribution** as the DRAM load
3. Data is now in **shared memory**, accessible to all threads in the block

**Why this step?**
- GEMM computation needs **all threads** to access **all data**
- Registers are per-thread; LDS is shared across the block
- LDS acts as a "staging area" for collaborative computation

### Phase 4: Synchronize

```cpp
block_sync_lds();
```

**What happens:**
- All threads in the block **wait** until everyone has finished storing to LDS
- Ensures no thread starts reading from LDS before all writes are complete
- Critical for correctness!

### Phase 5: GEMM Computation

```cpp
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
```

**What happens:**
1. The warp-level GEMM reads data from LDS
2. Performs matrix multiplication using MFMA instructions
3. Accumulates results into `c_block_tile` (in registers)

**Note:** `c_block_tile` stays in registers throughout all K iterations, accumulating results.

### Phase 6: Synchronize Again

```cpp
block_sync_lds();
```

**What happens:**
- Ensures all threads have finished reading from LDS
- Safe to overwrite LDS in the next iteration

---
## Memory Flow Diagram

```
Iteration 0 (K=0:31):
┌─────────┐  load_tile    ┌──────────┐  store_tile    ┌─────────┐
│  DRAM   │ ────────────> │  VGPRs   │ ─────────────> │   LDS   │
│ A[0:255,│               │  (per-   │                │ A_block │
│  0:31]  │               │  thread) │                │         │
└─────────┘               └──────────┘                └─────────┘
                                                           │
                                                           │ block_gemm
                                                           ↓
                                                      ┌──────────┐
                                                      │ c_block_ │
                                                      │   tile   │
                                                      │ (VGPRs)  │
                                                      └──────────┘

Iteration 1 (K=32:63):
┌─────────┐  load_tile    ┌──────────┐  store_tile    ┌─────────┐
│  DRAM   │ ────────────> │  VGPRs   │ ─────────────> │   LDS   │
│ A[0:255,│               │  (per-   │                │ A_block │
│ 32:63]  │               │  thread) │                │ (reused)│
└─────────┘               └──────────┘                └─────────┘
                                                           │
                                                           │ block_gemm
                                                           ↓
                                                      ┌──────────┐
                                                      │ c_block_ │
                                                      │   tile   │
                                                      │ (accum.) │
                                                      └──────────┘
```

---
## Example: Problem Size 512×256×64

### Block 0 Computation

**Input:**
- `a_dram_block_window_tmp`: Covers A[0:255, 0:31] initially
- `b_dram_block_window_tmp`: Covers B[0:127, 0:31] initially (B is transposed)
- `num_loop`: 2 (since K=64, KPerBlock=32)

**Iteration 0:**
1. Load A[0:255, 0:31] and B[0:127, 0:31] from DRAM to VGPRs
2. Move windows: A → [0:255, 32:63], B → [0:127, 32:63]
3. Store to LDS
4. Compute: `C[0:255, 0:127] += A[0:255, 0:31] × B[0:127, 0:31]^T`

**Iteration 1:**
1. Load A[0:255, 32:63] and B[0:127, 32:63] from DRAM to VGPRs
2. Move windows: A → [0:255, 64:95], B → [0:127, 64:95] (now past the end of K, but the loop exits before these windows are used)
3. Store to LDS
4. Compute: `C[0:255, 0:127] += A[0:255, 32:63] × B[0:127, 32:63]^T`

**Output:**
- `c_block_tile`: Contains C[0:255, 0:127] in distributed registers

---
## Key Concepts Summary

### 1. Tile Distribution
- **Maps threads to data elements** for load/store operations
- Each thread knows exactly which elements it's responsible for
- Enables **parallel, vectorized** memory access
- The **same distribution** is used for the DRAM load and the LDS store

### 2. Static Distributed Tensor
- **Per-thread register storage** (VGPRs)
- Each thread owns a **different slice** of the tile
- **Compile-time sized** for zero-overhead abstraction
- Used for: `a_block_tile`, `b_block_tile`, `c_block_tile`

### 3. Tile Window Movement
- Windows **slide** over larger tensors
- Enables iteration over the K dimension
- `move_tile_window(window, step)` updates the origin

### 4. LDS as Staging Area
- **Shared memory** accessible to all threads in a block
- Required because GEMM needs all threads to access all data
- **Reused** across K iterations (same LDS buffer)

### 5. Synchronization
- `block_sync_lds()` ensures memory consistency
- **Before GEMM**: all stores to LDS are complete
- **After GEMM**: all reads from LDS are complete

---
## Deep Dive: `static_distributed_tensor` Mechanics

### How Tile Distribution Creates Per-Thread Storage

When you call:
```cpp
using ABlockTile = decltype(make_static_distributed_tensor<fp16_t>(ABlockTileDistr{}));
ABlockTile a_block_tile;
```

**Step 1: Extract Thread Tensor Descriptor**

The tile distribution contains a `ys_to_d_descriptor` that maps:
- **Y dimensions** (logical tile coordinates, e.g., M, K)
- **D dimension** (per-thread register index, linearized)

```cpp
using ThreadTensorDesc =
    decltype(StaticTileDistribution{}.get_ys_to_d_descriptor());
```

**Step 2: Calculate Per-Thread Buffer Size**

```cpp
static constexpr index_t kThreadElementSpaceSize =
    ThreadTensorDesc{}.get_element_space_size();

static constexpr index_t get_thread_buffer_size()
{
    return kThreadElementSpaceSize / PackedSize;
}
```

**Example:**
- A 256×32 tile distributed across 256 threads
- Each thread owns 32 elements (one row)
- `thread_buffer_size = 32` (for PackedSize=1)
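The size computation above reduces to one division: the element-space size from the ys→d descriptor divided by the packing factor. A sketch (the function name is illustrative; `PackedSize` is 1 for unpacked fp16 in this tutorial):

```cpp
#include <cassert>

// Per-thread register-buffer size: element-space size divided by packing.
constexpr int thread_buffer_size(int element_space_size, int packed_size) {
    return element_space_size / packed_size;
}
```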
**Step 3: Allocate Thread Buffer**

```cpp
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
```

This is essentially:
```cpp
fp16_t data[32]; // Per-thread register array (VGPRs)
```

### Usage in Load/Store Operations

**Load from DRAM:**
```cpp
a_block_tile = load_tile(a_copy_dram_window);
```

What happens internally:
1. Each thread queries the tile distribution: "Which elements do I own?"
2. Thread 0 learns it owns A[0,0:31]
3. Thread 0 loads those elements from DRAM into `a_block_tile.thread_buf_[0:31]`
4. All 256 threads do this **in parallel**

**Store to LDS:**
```cpp
store_tile(a_copy_lds_window, a_block_tile);
```

What happens internally:
1. Each thread reads from its `a_block_tile.thread_buf_`
2. Thread 0 writes A[0,0:31] from its registers to LDS
3. All 256 threads do this **in parallel**
4. After `block_sync_lds()`, the entire tile is in shared LDS
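The load/store pair above amounts to staging the tile through per-thread buffers. Here is a host-side simulation of that round trip, assuming the one-row-per-thread distribution used in this tutorial (names and sizes are illustrative):

```cpp
#include <cassert>
#include <vector>

// DRAM -> per-thread buffers -> "LDS": after every thread stores its slice,
// shared memory holds the complete tile.
std::vector<float> stage_through_threads(const std::vector<float>& dram,
                                         int rows, int cols) {
    // load_tile: thread t copies row t into its private buffer
    std::vector<std::vector<float>> thread_buf(rows, std::vector<float>(cols));
    for (int t = 0; t < rows; ++t)
        for (int d = 0; d < cols; ++d)
            thread_buf[t][d] = dram[t * cols + d];

    // store_tile: thread t writes its buffer into the shared region
    std::vector<float> lds(rows * cols, 0.0f);
    for (int t = 0; t < rows; ++t)
        for (int d = 0; d < cols; ++d)
            lds[t * cols + d] = thread_buf[t][d];
    return lds; // after block_sync_lds(), visible to all threads
}
```

The staged copy reproduces the source tile exactly, which is what lets the warp GEMM read any element from LDS regardless of which thread loaded it.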
### Distributed Indexing

The `static_distributed_tensor` supports compile-time indexing:

```cpp
// Access using distributed indices
auto value = a_block_tile(tile_distributed_index<i, j>{});
```

Internally:
1. Convert the distributed index → Y index (logical tile coordinates)
2. Calculate the buffer offset using `ThreadTensorDesc`
3. Access `thread_buf_[offset]`

All of this happens **at compile time** with zero runtime overhead!

### Why This Design?

**Benefits:**
1. **Parallel Memory Access**: All threads load/store simultaneously
2. **Vectorization**: Each thread can use vector loads (e.g., 8×fp16 at once)
3. **Zero Overhead**: All indexing resolved at compile time
4. **Type Safety**: Distribution mismatch caught at compile time
5. **Register Pressure**: Compiler knows exact VGPR usage

**Trade-offs:**
- Requires compile-time tile sizes
- Distribution must be static
- More complex type system

### Memory Hierarchy Summary

```
┌─────────────────────────────────────────────────────────────┐
│ DRAM (Global Memory)                                        │
│ Full matrices A, B, C                                       │
└─────────────────────────────────────────────────────────────┘
                        │
                        │ load_tile (parallel, vectorized)
                        ↓
┌─────────────────────────────────────────────────────────────┐
│ VGPRs (Per-Thread Registers)                                │
│ Thread 0:   a_block_tile.thread_buf_ = [A[0,0:31]]          │
│ Thread 1:   a_block_tile.thread_buf_ = [A[1,0:31]]          │
│ ...                                                         │
│ Thread 255: a_block_tile.thread_buf_ = [A[255,0:31]]        │
│                                                             │
│ ← static_distributed_tensor manages this distribution       │
└─────────────────────────────────────────────────────────────┘
                        │
                        │ store_tile (parallel, vectorized)
                        ↓
┌─────────────────────────────────────────────────────────────┐
│ LDS (Shared Memory)                                         │
│ Entire block tile (256×32)                                  │
│ Accessible to all threads in block                          │
└─────────────────────────────────────────────────────────────┘
```

**Key Insight:**
`static_distributed_tensor` is the abstraction that enables efficient, parallel data movement between DRAM and LDS through per-thread VGPRs, with all coordination happening at compile time.
7
tutorial/ck_tile/01_naive_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
add_executable(tile_tutorial_naive_gemm EXCLUDE_FROM_ALL practice_gemm.cpp)

target_compile_options(tile_tutorial_naive_gemm PRIVATE
    -mllvm -enable-noalias-to-md-conversion=0
)

add_dependencies(tutorials tile_tutorial_naive_gemm)
618
tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md
Normal file
@@ -0,0 +1,618 @@
# Host-Level Pipeline: Orchestrating Block-Level GEMM

This document explains the **host-level pipeline** (`PracticeGemmHostPipeline`), which orchestrates the distribution of work across thread blocks and manages the high-level flow of the GEMM computation.

## Overview

The host-level pipeline is responsible for:
1. **Calculating tile coverage**: How many tiles are needed to cover matrices A, B, and C
2. **Block-to-tile mapping**: Assigning each thread block to a specific tile
3. **Creating tile windows**: Establishing sliding windows over tensor views
4. **Delegating computation**: Calling the block-level pipeline to perform the actual GEMM
5. **Storing results**: Writing computed tiles from registers (VGPRs) back to DRAM

```cpp
template <typename Problem_, typename Policy_ = PracticeGemmHostPolicy>
struct PracticeGemmHostPipeline
{
    template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
    CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
                                   const BDRAMTensorView& b_dram,
                                   CDRAMTensorView& c_dram) const
    {
        // 1. Calculate problem dimensions and tile coverage
        // 2. Map thread block to tile coordinates
        // 3. Create tile windows over A and B
        // 4. Call block-level pipeline to compute
        // 5. Store result to C
    }
};
```

---
## Step 1: Calculate Problem Dimensions and Tile Coverage

```cpp
// Size of the entire problem
const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // A is M x K
const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // C is M x N
const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // A is M x K

// Size of the block tile
const auto MPerBlock = BlockTile::at(number<0>{}); // 256
const auto NPerBlock = BlockTile::at(number<1>{}); // 128
const auto KPerBlock = BlockTile::at(number<2>{}); // 32

// Number of block tiles needed to cover C matrix
const auto num_tile_n = integer_divide_ceil(N, NPerBlock); // ceil(256/128) = 2
const auto num_tile_m = integer_divide_ceil(M, MPerBlock); // ceil(512/256) = 2
```

### What's Happening:

1. **Extract problem dimensions** from tensor descriptors:
   - `M = 512`: Rows in A and C
   - `N = 256`: Columns in B and C
   - `K = 64`: Inner dimension (columns of A, rows of B)

2. **Get block tile sizes** from the `BlockTile` configuration:
   - `MPerBlock = 256`: Each block processes 256 rows
   - `NPerBlock = 128`: Each block processes 128 columns
   - `KPerBlock = 32`: Each block processes 32 elements in the K dimension per iteration

3. **Calculate tile coverage**:
   - `num_tile_m = ceil(M / MPerBlock) = ceil(512/256) = 2` tiles in the M direction
   - `num_tile_n = ceil(N / NPerBlock) = ceil(256/128) = 2` tiles in the N direction
   - **Total tiles = 2 × 2 = 4 tiles** → We need **4 thread blocks**!

### Visual Representation:

```
Matrix C (512 × 256):
┌──────────────────────┬──────────────────────┐
│     Tile (0,0)       │     Tile (0,1)       │ ← num_tile_n = 2
│      256×128         │      256×128         │
│      Block 0         │      Block 2         │
│                      │                      │
├──────────────────────┼──────────────────────┤
│     Tile (1,0)       │     Tile (1,1)       │
│      256×128         │      256×128         │
│      Block 1         │      Block 3         │
│                      │                      │
└──────────────────────┴──────────────────────┘
        ↑
   num_tile_m = 2

Total blocks needed = 2 × 2 = 4 blocks

Each block computes one 256×128 tile of the output matrix C.
```

**Note**: The block-to-tile assignment follows the `MakeBlock2TileMap` mapping derived in Step 2: consecutive block IDs advance along M first, so Block 1 computes Tile (1,0) and Block 2 computes Tile (0,1).

### How Blocks Cover Matrices A and B:

```
Matrix A (512 × 64):          Matrix B (256 × 64):
┌─────────────┬──────┐        ┌─────────────┬──────┐
│  Block 0,2  │  K   │        │  Block 0,1  │  K   │
│  uses rows  │  →   │        │  uses rows  │  →   │
│  0-255      │      │        │  0-127      │      │
├─────────────┼──────┤        ├─────────────┼──────┤
│  Block 1,3  │  K   │        │  Block 2,3  │  K   │
│  uses rows  │  →   │        │  uses rows  │  →   │
│  256-511    │      │        │  128-255    │      │
└─────────────┴──────┘        └─────────────┴──────┘
  256 rows     64 cols          128 rows     64 cols

Each block needs to iterate over the K dimension (64/32 = 2 iterations)
```

---
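The coverage arithmetic above is simple enough to verify directly. A sketch re-deriving `integer_divide_ceil` (the helper name `ceil_div` is illustrative):

```cpp
#include <cassert>

// Ceiling division, as used to count block tiles along each dimension.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

constexpr int num_tile_m = ceil_div(512, 256);      // 2 tiles along M
constexpr int num_tile_n = ceil_div(256, 128);      // 2 tiles along N
constexpr int num_blocks = num_tile_m * num_tile_n; // 4 thread blocks
```

Note that ceiling division also handles sizes that are not multiples of the tile: for example, `ceil_div(65, 32)` is 3, so a ragged edge still gets a (partially used) tile.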
## Step 2: Map Thread Block to Tile Coordinates

```cpp
// Get block id (0 to total_blocks - 1)
const auto id_block = get_block_id();

// Map block id to 2D tile coordinates
const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n);
const auto tile_id   = block2tile(id_block);

const auto tile_id_m = tile_id.at(number<0>{}); // M coordinate
const auto tile_id_n = tile_id.at(number<1>{}); // N coordinate
```

### What's Happening:

Each thread block needs to know **which tile of the output matrix C it should compute**. The `MakeBlock2TileMap` function creates a mapping from a linear block ID to 2D tile coordinates.

### The `MakeBlock2TileMap` Function:

```cpp
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
{
    // Create a merge transform: (N0, M0) → linear index
    const auto unmerge = make_merge_transform(make_tuple(N0, M0));

    return [unmerge](index_t block_id) {
        multi_index<2> unmerged;
        // Convert linear block_id back to 2D coordinates
        unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));

        // Return (m_idx, n_idx) - note the swap!
        return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
    };
}
```

### In Our Example (2×2 Grid):

```cpp
// Block 0:
id_block = 0
tile_id = block2tile(0) = (0, 0) // Top-left tile
tile_id_m = 0, tile_id_n = 0

// Block 1:
id_block = 1
tile_id = block2tile(1) = (1, 0) // Bottom-left tile
tile_id_m = 1, tile_id_n = 0

// Block 2:
id_block = 2
tile_id = block2tile(2) = (0, 1) // Top-right tile
tile_id_m = 0, tile_id_n = 1

// Block 3:
id_block = 3
tile_id = block2tile(3) = (1, 1) // Bottom-right tile
tile_id_m = 1, tile_id_n = 1
```

**Key Point**: Each of the 4 blocks knows exactly which 256×128 tile of C it's responsible for computing!

---
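The mapping produced by `MakeBlock2TileMap` reduces to modular arithmetic: the merge of `(N0, M0)` linearizes as `block_id = n * M0 + m`, so decoding and swapping yields `(m, n)`. A host-side sketch (the function name is illustrative):

```cpp
#include <cassert>
#include <utility>

// Decode a linear block id into (tile_id_m, tile_id_n), matching the
// merge-transform arithmetic described above (M0 = num_tile_m).
std::pair<int, int> block_to_tile(int block_id, int M0) {
    return {block_id % M0, block_id / M0};
}
```

This reproduces the 2×2 example: blocks 0..3 map to (0,0), (1,0), (0,1), (1,1).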
## Step 3: Calculate Tile Origin and Create Tile Windows
|
||||
|
||||
```cpp
|
||||
// Calculate the starting position of this tile in the global matrix
|
||||
const auto tile_origin_m = tile_id_m * MPerBlock; // e.g., Block 1: 1 * 256 = 256
|
||||
const auto tile_origin_n = tile_id_n * NPerBlock; // e.g., Block 2: 1 * 128 = 128
|
||||
|
||||
// Create tile windows over A and B tensor views
|
||||
const auto a_block_window = make_tile_window(
|
||||
a_dram, // Tensor view over A
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // Window size: 256×32
|
||||
{tile_origin_m, 0} // Origin: varies by block
|
||||
);
|
||||
|
||||
const auto b_block_window = make_tile_window(
|
||||
b_dram, // Tensor view over B
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), // Window size: 128×32
|
||||
{tile_origin_n, 0} // Origin: varies by block
|
||||
);
|
||||
```
|
||||
|
||||
### Tile Origins for Each Block:

```cpp
// Block 0 (Tile 0,0):
tile_origin_m = 0 * 256 = 0
tile_origin_n = 0 * 128 = 0
a_block_window origin: (0, 0)   → covers A rows 0-255
b_block_window origin: (0, 0)   → covers B rows 0-127

// Block 1 (Tile 1,0):
tile_origin_m = 1 * 256 = 256
tile_origin_n = 0 * 128 = 0
a_block_window origin: (256, 0) → covers A rows 256-511
b_block_window origin: (0, 0)   → covers B rows 0-127

// Block 2 (Tile 0,1):
tile_origin_m = 0 * 256 = 0
tile_origin_n = 1 * 128 = 128
a_block_window origin: (0, 0)   → covers A rows 0-255
b_block_window origin: (128, 0) → covers B rows 128-255

// Block 3 (Tile 1,1):
tile_origin_m = 1 * 256 = 256
tile_origin_n = 1 * 128 = 128
a_block_window origin: (256, 0) → covers A rows 256-511
b_block_window origin: (128, 0) → covers B rows 128-255
```
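The block-to-tile mapping and origin arithmetic above can be verified on the host with a few lines of plain C++. This is a standalone sketch; `tile_coords` and `tile_origin` are illustrative helpers, not CK Tile functions:

```cpp
#include <utility>

// Invert the linear block id into 2D tile coordinates. Column-major mapping:
// tile_id_m varies fastest, matching the tutorial's block -> tile table.
std::pair<int, int> tile_coords(int block_id, int num_tile_m)
{
    const int tile_id_m = block_id % num_tile_m;
    const int tile_id_n = block_id / num_tile_m;
    return {tile_id_m, tile_id_n};
}

// Tile origin in C for the tutorial's 256x128 block tile.
std::pair<int, int> tile_origin(int tile_id_m, int tile_id_n)
{
    return {tile_id_m * 256, tile_id_n * 128};
}
```

Running this for block ids 0-3 reproduces exactly the four `(tile, origin)` pairs listed above.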

### What are Tile Windows?

A **tile window** is a **sliding window** over a larger tensor view. It:
- Defines a **rectangular region** within the tensor
- Has a **fixed size** (e.g., 256×32 for A)
- Has an **origin** (starting position)
- Can be **moved** to access different regions

### Visual Representation (Block 0 Example):

```
Matrix A (512 × 64):                  Matrix B (256 × 64):
┌─────────────┬─────────────┐         ┌─────────────┬─────────────┐
│ ┏━━━━━━━━━┓ │             │         │ ┏━━━━━━━━━┓ │             │
│ ┃ Window  ┃ │             │         │ ┃ Window  ┃ │             │
│ ┃ 256×32  ┃ │             │         │ ┃ 128×32  ┃ │             │
│ ┃ K=0-31  ┃ │             │         │ ┃ K=0-31  ┃ │             │
│ ┗━━━━━━━━━┛ │             │         │ ┗━━━━━━━━━┛ │             │
│             │             │         ├─────────────┼─────────────┤
├─────────────┼─────────────┤         │             │             │
│             │             │         │             │             │
│             │             │         │             │             │
│             │             │         │             │             │
└─────────────┴─────────────┘         └─────────────┴─────────────┘
Origin: (0, 0)                        Origin: (0, 0)
Covers rows 0-255                     Covers rows 0-127
Covers cols 0-31 (first K iteration)  Covers cols 0-31 (first K iteration)
```

**Note**: The window initially covers K columns 0-31. It will move to cover K columns 32-63 in the next iteration.

### Tile Window Properties:

```cpp
// Tile window structure (conceptual):
struct tile_window {
    TensorView& tensor_view;  // Reference to underlying tensor
    Tuple window_lengths;     // Size of the window (256, 32)
    MultiIndex window_origin; // Starting position (0, 0)

    // Can move the window:
    void move(MultiIndex step); // Shift window by step

    // Access data through the window:
    auto load(); // Load data from windowed region
};
```

### Tile Window Movement: Iterating Over K Dimension

In our example, **K=64** but **KPerBlock=32**, so we need **2 iterations** over the K dimension:

```
Matrix A (512 × 64) - Block 0's view:
┌─────────────┬─────────────┐
│ ┏━━━━━━━━━┓ │ ╔═════════╗ │
│ ┃ Iter 0  ┃ │ ║ Iter 1  ║ │  ← Window slides along K
│ ┃ 256×32  ┃ │ ║ 256×32  ║ │
│ ┃ K=0-31  ┃ │ ║ K=32-63 ║ │
│ ┗━━━━━━━━━┛ │ ╚═════════╝ │
├─────────────┼─────────────┤
│             │             │
│  Block 1's  │             │
│  region     │             │
└─────────────┴─────────────┘

Matrix B (256 × 64) - Block 0's view:
┌─────────────┬─────────────┐
│ ┏━━━━━━━━━┓ │ ╔═════════╗ │
│ ┃ Iter 0  ┃ │ ║ Iter 1  ║ │
│ ┃ 128×32  ┃ │ ║ 128×32  ║ │
│ ┃ K=0-31  ┃ │ ║ K=32-63 ║ │
│ ┗━━━━━━━━━┛ │ ╚═════════╝ │
├─────────────┼─────────────┤
│  Block 2's  │             │
│  region     │             │
└─────────────┴─────────────┘
```

### How Windows Move (Conceptual - handled by block pipeline):

```cpp
// Iteration 0:
a_block_window origin: (tile_origin_m, 0)  // K columns 0-31
b_block_window origin: (tile_origin_n, 0)  // K columns 0-31
// Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31]ᵀ  (B is stored N×K)

// Move windows to next K position:
move_tile_window(a_block_window, {0, 32});
move_tile_window(b_block_window, {0, 32});

// Iteration 1:
a_block_window origin: (tile_origin_m, 32) // K columns 32-63
b_block_window origin: (tile_origin_n, 32) // K columns 32-63
// Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63]ᵀ

// Final result:
// C_tile = C_partial_0 + C_partial_1
```

**Key Insight**: The tile windows **slide along the K dimension** to cover the full inner product. Each block accumulates partial results across K iterations to compute its final tile of C.

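Numerically, the sliding windows implement a K-split inner product. A host-side reference sketch (plain C++, not CK Tile code) that mirrors the chunked accumulation, with B stored N×K as in the tutorial:

```cpp
#include <vector>

// Host-side reference for the K-chunked accumulation the windows implement:
// C[m][n] = sum_k A[m][k] * B[n][k], with K processed KPerBlock columns at a
// time — one "move_tile_window" step per chunk.
std::vector<float> gemm_k_chunked(const std::vector<float>& A, // M x K, row-major
                                  const std::vector<float>& B, // N x K, row-major
                                  int M, int N, int K, int KPerBlock)
{
    std::vector<float> C(M * N, 0.0f);
    for(int k0 = 0; k0 < K; k0 += KPerBlock)     // window slides along K
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = k0; k < k0 + KPerBlock && k < K; ++k)
                    C[m * N + n] += A[m * K + k] * B[n * K + k]; // accumulate partials
    return C;
}
```

Because each chunk only accumulates into C, splitting K changes the order of additions but not the result (exactly so for integer-valued inputs).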
---

## Step 4: Delegate to Block-Level Pipeline

```cpp
// Get the block-level pipeline from policy
constexpr auto block_gemm_pipeline =
    Policy::template GetPracticeGemmBlockPipeline<Problem>();

// Calculate the number of K iterations needed
int num_loops_k = integer_divide_ceil(K, KPerBlock); // ceil(64/32) = 2

// Allocate shared memory (LDS) for block-level computation
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()];

// Call the block-level pipeline to compute the C tile
const auto c_block_tile =
    block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char);
```

### What's Happening:

1. **Retrieve block pipeline**: The policy provides the block-level GEMM implementation

2. **Calculate K iterations**: How many times to iterate over the K dimension
   - In our example: `K=64, KPerBlock=32` → **2 iterations**
   - Each iteration processes 32 elements of the K dimension
   - Results are accumulated across iterations

3. **Allocate shared memory**:
   - `__shared__` declares memory shared by all threads in the block
   - `GetStaticLDSSize()` returns the required size in bytes
   - This memory is used for:
     - Staging data from DRAM → LDS
     - Cooperative loading by threads
     - Fast access during computation

4. **Execute block pipeline**:
   - Takes A and B tile windows as input
   - Performs the GEMM computation: `C_tile = A_tile × B_tile`
   - Returns the result in `c_block_tile` (stored in VGPRs - registers)

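The loop count comes from a ceiling division. A host-side equivalent of the semantics (a sketch, not ck_tile's actual implementation):

```cpp
// How many KPerBlock-sized chunks are needed to cover K elements:
// any remainder still costs one full iteration.
int integer_divide_ceil_ref(int a, int b) { return (a + b - 1) / b; }
```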
### Memory Hierarchy During Computation:

```
┌─────────────────────────────────────────────────────────────┐
│  DRAM (Global Memory) - Slowest, Largest                    │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐          │
│  │  A matrix   │  │  B matrix   │  │  C matrix   │          │
│  └─────────────┘  └─────────────┘  └─────────────┘          │
└─────────────────────────────────────────────────────────────┘
        ↓ load            ↓ load            ↑ store
┌─────────────────────────────────────────────────────────────┐
│  LDS (Shared Memory) - Fast, Limited Size (~64KB)           │
│  ┌─────────────┐  ┌─────────────┐                           │
│  │   A_tile    │  │   B_tile    │  ← Staged here            │
│  │  (p_smem)   │  │  (p_smem)   │                           │
│  └─────────────┘  └─────────────┘                           │
└─────────────────────────────────────────────────────────────┘
        ↓ load            ↓ load
┌─────────────────────────────────────────────────────────────┐
│  VGPRs (Registers) - Fastest, Smallest (~256 regs/thread)   │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │  c_block_tile (accumulated result)                      │ │
│ │  Computation happens here using MFMA instructions       │ │
│ └─────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
```

### Block Pipeline Responsibilities:

The block pipeline (called here) will:
1. Load A and B tiles from DRAM → LDS (cooperative loading)
2. Distribute work among warps
3. Each warp loads its portion from LDS → VGPRs
4. Perform MFMA operations: `C += A × B`
5. Accumulate results in VGPRs
6. Return the final `c_block_tile` in registers

---

## Step 5: Store Results to DRAM

```cpp
// Create a tile window over C for writing results
auto c_window = make_tile_window(
    c_dram,                                               // Tensor view over C
    make_tuple(number<MPerBlock>{}, number<NPerBlock>{}), // Window size: 256×128
    {tile_origin_m, tile_origin_n}                        // Origin: varies by block
);

// Store the computed tile from VGPRs to DRAM
store_tile(c_window, c_block_tile);
```

### C Window Origins for Each Block:

```cpp
// Block 0: Writes to the top-left tile
c_window origin: (0, 0)     → writes to C[0:255, 0:127]

// Block 1: Writes to the bottom-left tile
c_window origin: (256, 0)   → writes to C[256:511, 0:127]

// Block 2: Writes to the top-right tile
c_window origin: (0, 128)   → writes to C[0:255, 128:255]

// Block 3: Writes to the bottom-right tile
c_window origin: (256, 128) → writes to C[256:511, 128:255]
```
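A quick host-side check that these four windows partition C exactly — every output element is owned by one and only one block. `owner_block` is an illustrative helper, not CK Tile API:

```cpp
// Which block writes C[m][n]? Uses the tutorial's 256x128 block tile and the
// column-major block ordering (tile_id_m varies fastest).
int owner_block(int m, int n)
{
    const int MPerBlock = 256, NPerBlock = 128, num_tile_m = 2;
    const int tile_id_m = m / MPerBlock;
    const int tile_id_n = n / NPerBlock;
    return tile_id_n * num_tile_m + tile_id_m;
}
```

Spot-checking one point per quadrant reproduces the block-to-region table above, and since the mapping is total and single-valued, no element of C is written twice or skipped.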

### What's Happening:

1. **Create C tile window**:
   - Size: 256×128 (matches our block tile size)
   - Origin: Varies by block - each block writes to its assigned region
   - This window defines **where** to write the results

2. **Store tile to DRAM**:
   - `c_block_tile`: Computed results in VGPRs (registers)
   - `c_window`: Destination window in DRAM
   - `store_tile()`: Efficiently writes data from registers → DRAM

### The `store_tile` Function:

Recall from our earlier discussion what `store_tile` does:

```cpp
template <typename TileWindow, typename DistributedTensor>
void store_tile(TileWindow& tile_window_tmp,
                const DistributedTensor& dstr_tensor)
{
    // 1. Extract tile distribution from distributed tensor
    using TileDstr = typename DistributedTensor::TileDistribution;

    // 2. Upgrade simple tile window to one with distribution
    auto tile_window = make_tile_window(
        tile_window_tmp.get_bottom_tensor_view(),
        tile_window_tmp.get_window_lengths(),
        tile_window_tmp.get_window_origin(),
        TileDstr{} // Add distribution info
    );

    // 3. Store using vectorized writes
    tile_window.store(dstr_tensor);
}
```

### Memory Flow:

```
VGPRs (Registers)                     DRAM (Global Memory)
┌─────────────────────┐               ┌─────────────────────┐
│  c_block_tile       │               │  C matrix           │
│  ┌───┬───┬───┬───┐  │  store_tile   │  ┌───────────────┐  │
│  │W0 │W1 │W2 │W3 │  │  ==========>  │  │   c_window    │  │
│  ├───┼───┼───┼───┤  │  vectorized   │  │   (256×128)   │  │
│  │...│...│...│...│  │               │  │               │  │
│  └───┴───┴───┴───┘  │               │  └───────────────┘  │
│  Distributed across │               │  Origin: (0, 0)     │
│  threads/warps      │               │                     │
└─────────────────────┘               └─────────────────────┘

Each thread writes its portion using vector stores (e.g., float4)
```

### Store Optimization:

The `store_tile` function:
- Uses **vectorized stores** (writes multiple elements at once)
- Ensures **coalesced memory access** (adjacent threads write adjacent memory)
- Respects the **tile distribution** (each thread knows what data it owns)
- Handles **out-of-bounds** checking (for partial tiles at boundaries)

---

## Complete Flow Visualization

Let's trace the complete flow for **Block 0** (other blocks follow the same pattern):

```
┌─────────────────────────────────────────────────────────────────┐
│ Step 1: Calculate Tile Coverage                                 │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ M=512, N=256, K=64                                          │ │
│ │ MPerBlock=256, NPerBlock=128, KPerBlock=32                  │ │
│ │ num_tile_m = ceil(512/256) = 2                              │ │
│ │ num_tile_n = ceil(256/128) = 2                              │ │
│ │ Total blocks needed = 2 × 2 = 4 blocks                      │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────────┐
│ Step 2: Map Block to Tile (Block 0 example)                     │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Block ID: 0                                                 │ │
│ │ Tile coordinates: (0, 0) - top-left tile                    │ │
│ │ Tile origin: (0, 0)                                         │ │
│ │                                                             │ │
│ │ (Blocks 1,2,3 get different tile coordinates)               │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────────┐
│ Step 3: Create Tile Windows                                     │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ a_block_window: 256×32 starting at (0,0) over A             │ │
│ │ b_block_window: 128×32 starting at (0,0) over B             │ │
│ │ Windows initially cover K columns 0-31                      │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────────┐
│ Step 4: Execute Block Pipeline (2 K iterations)                 │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Allocate shared memory (LDS)                                │ │
│ │ Call block_gemm_pipeline(a_window, b_window, 2, p_smem)     │ │
│ │                                                             │ │
│ │ K Iteration 0 (K=0-31):                                     │ │
│ │ ├─ Load A tile: DRAM → LDS → VGPRs                          │ │
│ │ ├─ Load B tile: DRAM → LDS → VGPRs                          │ │
│ │ ├─ Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31]           │ │
│ │ └─ Move windows: {0, 32}                                    │ │
│ │                                                             │ │
│ │ K Iteration 1 (K=32-63):                                    │ │
│ │ ├─ Load A tile: DRAM → LDS → VGPRs                          │ │
│ │ ├─ Load B tile: DRAM → LDS → VGPRs                          │ │
│ │ ├─ Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63]         │ │
│ │ └─ Accumulate: C_tile = C_partial_0 + C_partial_1           │ │
│ │                                                             │ │
│ │ Return c_block_tile in VGPRs (256×128 accumulated result)   │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────────┐
│ Step 5: Store Results                                           │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Create c_window: 256×128 starting at (0,0) over C           │ │
│ │ store_tile(c_window, c_block_tile)                          │ │
│ │ └─ Write from VGPRs → DRAM (vectorized stores)              │ │
│ │                                                             │ │
│ │ Block 0 writes to C[0:255, 0:127]                           │ │
│ │ (Other blocks write to their respective regions)            │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘

All 4 blocks execute in parallel, each computing its assigned 256×128 tile!
```

---

## Key Concepts Summary

### 1. **Tile Coverage**
- Determines how many thread blocks are needed
- Each block processes one tile of the output matrix C
- Calculated as `ceil(dimension / tile_size)`

### 2. **Block-to-Tile Mapping**
- Maps a linear block ID to 2D tile coordinates
- Uses column-major ordering for better memory coalescing
- Each block knows which tile it's responsible for

### 3. **Tile Windows**
- **Sliding windows** over larger tensor views
- Define a rectangular region with a fixed size and movable origin
- Provide efficient, structured access to tensor data
- Can be moved to access different regions (e.g., for K iterations)

### 4. **Memory Hierarchy**
- **DRAM (Global)**: Largest, slowest - stores full matrices
- **LDS (Shared)**: Medium, fast - stages tiles for cooperative access
- **VGPRs (Registers)**: Smallest, fastest - performs computation

### 5. **Data Flow**
```
DRAM → Tile Windows → LDS → VGPRs → Computation → VGPRs → DRAM
  ↑                                                        ↓
A, B matrices                                          C matrix
```

---

## Next Steps

The host-level pipeline has set up the work and delegated to the block-level pipeline. Next, we'll explore:
- **Block-level pipeline**: How tiles are loaded, distributed to warps, and computed
- **Warp-level pipeline**: How warps perform MFMA operations
- **Memory optimization**: LDS usage, bank conflicts, coalescing

The host level provides the **orchestration**, while the block and warp levels provide the **execution**!

464
tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md
Normal file

# PracticeGemmKernel: Understanding the Kernel Entry Point

This document explains the `PracticeGemmKernel` structure, which serves as the **entry point** for our GEMM GPU kernel. We'll dive deep into how raw memory is transformed into structured tensor views.

## Overview

The `PracticeGemmKernel` is a templated struct that:
1. Takes raw device memory pointers for matrices A, B, and C
2. Wraps them into **tensor views** - logical, structured views over physical memory
3. Dispatches to the host-level pipeline for computation

```cpp
template <typename Problem_, typename Policy_>
struct PracticeGemmKernel
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    static constexpr index_t kBlockSize = 256;

    CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
                                   const typename Problem::BDataType* p_b,
                                   typename Problem::CDataType* p_c,
                                   const index_t M,
                                   const index_t N,
                                   const index_t K,
                                   const index_t stride_a,
                                   const index_t stride_b,
                                   const index_t stride_c) const
    {
        // Step 1: Create tensor views over raw memory
        auto a_dram = make_naive_tensor_view<address_space_enum::global>(
            p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{});

        auto b_dram = make_naive_tensor_view<address_space_enum::global>(
            p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{});

        const auto c_dram = make_naive_tensor_view<address_space_enum::global>(
            p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{});

        // Step 2: Dispatch to host-level pipeline
        PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram);
    }
};
```

---

## What are Tensor Views?

A **tensor view** is a **logical, structured view over raw physical memory**. It doesn't own or allocate memory—it simply provides a way to interpret and access existing memory as a multi-dimensional tensor.

### Key Components of a Tensor View:

1. **Memory Type**: Where the data lives (global/DRAM, LDS/shared, registers)
2. **Raw Pointer**: Points to the actual data in memory
3. **Shape**: Dimensions of the tensor (e.g., M×K for matrix A)
4. **Strides**: How to navigate through memory to access elements
5. **Guaranteed Vector Length**: How many consecutive elements can be loaded in one vector instruction
6. **Guaranteed Vector Stride**: The stride between those vectorizable elements

---

## The Memory Abstraction Hierarchy

CK Tile uses a three-layer abstraction to go from raw memory to structured tensors:

```
┌─────────────────────────────────────────────────────────────┐
│ Layer 3: TENSOR VIEW                                        │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ • Logical multi-dimensional structure                   │ │
│ │ • Shape: (M, K) = (256, 32)                             │ │
│ │ • Strides: (32, 1) for row-major layout                 │ │
│ │ • Provides: operator[], coordinate-based access         │ │
│ │ • Knows: How to map (i,j) → linear offset               │ │
│ └─────────────────────────────────────────────────────────┘ │
│                         ↓ wraps                             │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Layer 2: BUFFER VIEW                                    │ │
│ │ ┌─────────────────────────────────────────────────────┐ │ │
│ │ │ • Linear view of memory                             │ │ │
│ │ │ • Pointer: p_data_ → device memory                  │ │ │
│ │ │ • Size: Total number of elements                    │ │ │
│ │ │ • Address space: global/LDS/generic                 │ │ │
│ │ │ • Provides: Vectorized loads/stores, bounds checking│ │ │
│ │ └─────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│                         ↓ wraps                             │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Layer 1: RAW PHYSICAL MEMORY                            │ │
│ │ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ │
│ │ │ 0.0 │ 1.0 │ 2.0 │ 3.0 │ 4.0 │ 5.0 │ 6.0 │ 7.0 │ ... │ │ │
│ │ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ │ │
│ │    ↑                                                    │ │
│ │    p_a (raw pointer from hipMalloc)                     │ │
│ └─────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
```

---

## Deep Dive: `make_naive_tensor_view`

Let's break down the function call for matrix A:

```cpp
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
    p_a,                     // Raw pointer to device memory
    make_tuple(M, K),        // Shape: (256, 32)
    make_tuple(stride_a, 1), // Strides: (32, 1) - row-major
    number<8>{},             // Guaranteed vector length
    number<1>{}              // Guaranteed vector stride
);
```

### Function Signature:

```cpp
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
          memory_operation_enum DstInMemOp      = memory_operation_enum::set,
          amd_buffer_coherence_enum Coherence   = amd_buffer_coherence_enum::coherence_default,
          typename DataType,
          typename... Lengths,
          typename... Strides,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view(DataType* __restrict__ p,
                       const tuple<Lengths...>& lengths,
                       const tuple<Strides...>& strides,
                       number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                       number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    // Step 1: Create tensor descriptor (shape + stride information)
    auto desc = make_naive_tensor_descriptor(lengths,
                                             strides,
                                             number<GuaranteedLastDimensionVectorLength>{},
                                             number<GuaranteedLastDimensionVectorStride>{});

    // Step 2: Create buffer view (pointer + size + address space)
    auto buffer_view =
        make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());

    // Step 3: Combine into tensor view
    return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
}
```

---

## Parameter Breakdown

### 1. **Template Parameter: `address_space_enum::global`**

Specifies where the memory lives:
- `global`: GPU global memory (DRAM) - slowest but largest
- `lds`: Local Data Share (shared memory) - fast, limited size
- `generic`: Generic address space
- `vgpr`: Vector General Purpose Registers - fastest, smallest

In our case, `global` means the data is in GPU DRAM.

### 2. **`p_a` - Raw Pointer**

The raw device memory pointer returned by `hipMalloc`. Points to the start of the matrix data.

### 3. **`make_tuple(M, K)` - Shape/Lengths**

Defines the logical dimensions of the tensor:
- For matrix A: `(256, 32)` means 256 rows, 32 columns
- This is the **logical view**, independent of how data is physically laid out

### 4. **`make_tuple(stride_a, 1)` - Strides**

Defines how to navigate through memory:
- **Stride for dimension 0 (rows)**: `stride_a = K = 32`
  - To move to the next row, skip 32 elements
- **Stride for dimension 1 (columns)**: `1`
  - To move to the next column, skip 1 element

**Row-major layout example:**
```
Memory: [a₀₀, a₀₁, a₀₂, ..., a₀,₃₁, a₁₀, a₁₁, a₁₂, ..., a₁,₃₁, ...]
         ↑                          ↑
         Row 0 starts here          Row 1 starts here (offset = 32)

To access element A[i][j]:
  offset = i * stride_a + j * 1
         = i * 32 + j
```
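The offset formula is easy to sanity-check in plain C++ (a standalone sketch using the tutorial's row stride of 32):

```cpp
// Row-major offset of element (i, j) in a matrix whose rows are
// stride_row elements apart; the column stride is 1 (contiguous).
int linear_offset(int i, int j, int stride_row) { return i * stride_row + j; }
```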

### 5. **`number<8>{}` - Guaranteed Last Dimension Vector Length**

This tells the tensor view: **"The last dimension (K) is guaranteed to have at least 8 consecutive elements that can be loaded together in a single vector instruction."**

#### Why is this important?

Modern GPUs can load multiple elements in one instruction (vectorized loads):
- `float4`: Load 4 floats at once
- `float8`: Load 8 floats at once (if supported)

By specifying `number<8>{}`, we're telling the system:
- "You can safely use vector loads of up to 8 elements"
- "The memory alignment and layout support this"

**Example:**
```cpp
// Without vectorization (slow):
for (int j = 0; j < 8; j++) {
    data[j] = memory[offset + j]; // 8 separate loads
}

// With vectorization (fast):
float8 vec = *reinterpret_cast<float8*>(&memory[offset]); // 1 load!
```

### 6. **`number<1>{}` - Guaranteed Last Dimension Vector Stride**

This specifies the **stride between consecutive vectorizable elements** in the last dimension.

- `number<1>{}` means: "Consecutive elements in the last dimension are contiguous in memory (stride = 1)"
- This confirms that elements `A[i][0], A[i][1], A[i][2], ..., A[i][7]` are stored consecutively

**Why does this matter?**

For efficient vectorized loads, elements must be:
1. **Contiguous** (stride = 1) ✓
2. **Aligned** properly in memory
3. **Within the same cache line** (ideally)

If the stride were `2`, it would mean:
```
A[i][0] is at offset 0
A[i][1] is at offset 2 (not 1!)
A[i][2] is at offset 4
```
This would prevent efficient vectorization.

---
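A host-side way to picture a vector load: copy four contiguous floats in one operation. Real kernels use hardware vector types such as `float4`; this sketch (with an invented `float4_sim` type) only illustrates why stride-1 contiguity is the precondition:

```cpp
#include <cstring>

// Stand-in for a 4-wide vector register.
struct float4_sim { float x[4]; };

// One 16-byte copy instead of four scalar loads — only valid because the
// four source elements are contiguous (stride 1).
float4_sim load4(const float* p)
{
    float4_sim v;
    std::memcpy(v.x, p, sizeof(v.x));
    return v;
}
```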

## What is a Buffer View?

A **buffer view** is the middle layer between raw memory and the tensor view. It provides:

### Core Responsibilities:

1. **Memory Management**
   - Holds the raw pointer: `T* p_data_`
   - Tracks buffer size: `BufferSizeType buffer_size_`
   - Knows the address space: `global`, `lds`, etc.

2. **Vectorized Access**
   ```cpp
   template <typename VectorType>
   CK_TILE_DEVICE VectorType get(index_t offset);
   ```
   - Provides efficient vector loads/stores
   - Handles alignment requirements

3. **Bounds Checking** (optional)
   ```cpp
   template <bool oob_conditional_check = true>
   CK_TILE_DEVICE auto get(index_t i, index_t linear_offset);
   ```
   - Can optionally check if access is within bounds
   - Returns an invalid value (default 0) for out-of-bounds access

4. **Address Space Awareness**
   - Uses different load/store instructions based on address space
   - Global memory: `global_load`, `global_store`
   - LDS: `ds_read`, `ds_write`

### Buffer View Structure:

```cpp
template <address_space_enum BufferAddressSpace,
          typename T,
          typename BufferSizeType,
          bool InvalidElementUseNumericalZeroValue,
          amd_buffer_coherence_enum Coherence>
struct buffer_view
{
    T* p_data_;                               // Raw pointer
    BufferSizeType buffer_size_;              // Total elements
    remove_cvref_t<T> invalid_element_value_; // Value for OOB access

    // Access operators
    const T& operator[](index_t i) const; // Read
    T& operator()(index_t i);             // Write

    // Vectorized access
    template <typename VectorType>
    VectorType get(index_t offset);
};
```

---

## Visual Example: Matrix A Memory Layout

Let's visualize how matrix A (256×32, fp16) is organized:

### Raw Physical Memory (Linear):
```
GPU DRAM Address Space:
┌─────────────────────────────────────────────────────────────────┐
│ Byte 0                                                          │
│ ↓                                                               │
│ [a₀₀][a₀₁][a₀₂]...[a₀,₃₁][a₁₀][a₁₁][a₁₂]...[a₁,₃₁][a₂₀]...      │
│  ↑                        ↑                                     │
│  Row 0 (32 elements)      Row 1 (32 elements)                   │
│                                                                 │
│ Total: 256 rows × 32 cols × 2 bytes/element = 16,384 bytes      │
└─────────────────────────────────────────────────────────────────┘
  ↑
  p_a (raw pointer)
```

### Buffer View Layer:
```
buffer_view<address_space_enum::global, fp16_t, ...>
┌─────────────────────────────────────────────────────────────────┐
│ p_data_       = p_a                                             │
│ buffer_size_  = 256 × 32 = 8,192 elements                       │
│ address_space = global (DRAM)                                   │
│                                                                 │
│ Provides:                                                       │
│ • Linear indexing: buffer_view[i] → element at offset i         │
│ • Vectorized loads: get<float4>(offset) → load 4 fp16s at once  │
│ • Bounds checking: is offset < buffer_size_?                    │
└─────────────────────────────────────────────────────────────────┘
```

### Tensor View Layer:
```
tensor_view<buffer_view, tensor_descriptor>
┌─────────────────────────────────────────────────────────────────┐
│ Shape: (256, 32)                                                │
│ Strides: (32, 1)                                                │
│ Guaranteed vector length: 8                                     │
│ Guaranteed vector stride: 1                                     │
│                                                                 │
│ Logical 2D View:                                                │
│          Col:  0    1    2   ...  31                            │
│ Row 0:   [a₀₀][a₀₁][a₀₂] ... [a₀,₃₁]  ← Can vector load 8 at once│
│ Row 1:   [a₁₀][a₁₁][a₁₂] ... [a₁,₃₁]                            │
│ Row 2:   [a₂₀][a₂₁][a₂₂] ... [a₂,₃₁]                            │
│ ...                                                             │
│ Row 255: [a₂₅₅,₀]        ... [a₂₅₅,₃₁]                          │
│                                                                 │
│ Provides:                                                       │
│ • Multi-dimensional indexing: A[i][j]                           │
│ • Coordinate transformation: (i,j) → linear offset = i*32 + j   │
│ • Tile window creation: Extract sub-tensors                     │
└─────────────────────────────────────────────────────────────────┘
```

---

## Complete Flow: Raw Memory → Tensor View
|
||||
|
||||
Let's trace the complete transformation for matrix A:
|
||||
|
||||
### Step 1: Kernel Launch (Host Side)
|
||||
```cpp
|
||||
// On host: Allocate device memory
|
||||
hipMalloc(&p_a, M * K * sizeof(fp16_t)); // Returns raw pointer
|
||||
|
||||
// Launch kernel
|
||||
kernel<<<grid, block>>>(p_a, p_b, p_c, M, N, K, ...);
|
||||
```
|
||||
|
||||
### Step 2: Inside Kernel (Device Side)
|
||||
```cpp
|
||||
// Receive raw pointer
|
||||
const fp16_t* p_a; // Points to GPU DRAM
|
||||
|
||||
// Step 2a: Create tensor descriptor
|
||||
auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(256, 32), // Shape
|
||||
make_tuple(32, 1), // Strides
|
||||
number<8>{}, // Vector length
|
||||
number<1>{} // Vector stride
|
||||
);
|
||||
// desc now knows: "This is a 256×32 tensor, row-major, vectorizable by 8"
|
||||
|
||||
// Step 2b: Create buffer view
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_a, // Raw pointer
|
||||
256 * 32 // Total elements
|
||||
);
|
||||
// buffer_view now wraps p_a with size and address space info
|
||||
|
||||
// Step 2c: Create tensor view
|
||||
auto a_dram = tensor_view{buffer_view, desc};
|
||||
// a_dram now provides structured, multi-dimensional access to p_a
|
||||
```

### Step 3: Using the Tensor View
```cpp
// Access element A[i][j]
auto value = a_dram[make_tuple(i, j)];

// Create a tile window (sub-tensor)
auto tile = make_tile_window(
    a_dram,
    make_tuple(16, 16), // 16×16 tile
    make_tuple(0, 0)    // Starting at origin
);

// Load tile into registers with vectorization
auto tile_data = load_tile(tile); // Uses vector loads internally!
```

---

## Why This Abstraction?

### Benefits:

1. **Type Safety**: Can't accidentally access the wrong dimensions
2. **Performance**: The compiler knows about vectorization opportunities
3. **Flexibility**: The same code works for different memory spaces (DRAM, LDS, registers)
4. **Maintainability**: Logical structure stays separate from physical layout
5. **Optimization**: Guaranteed vector properties enable aggressive optimizations

### Example: Without Tensor Views (Manual Indexing)
```cpp
// Verbose, error-prone, hard to optimize:
for (int i = 0; i < 16; i++) {
    for (int j = 0; j < 16; j++) {
        float val = p_a[(tile_offset_i + i) * stride_a + (tile_offset_j + j)];
        // Hope the compiler vectorizes this? 🤞
    }
}
```

### Example: With Tensor Views (Clean, Optimized)
```cpp
// Clean, safe, automatically vectorized:
auto tile = make_tile_window(a_dram, make_tuple(16, 16), origin);
auto tile_data = load_tile(tile); // Vectorized loads guaranteed!
```

---

## Summary

The `PracticeGemmKernel` entry point transforms raw GPU memory into structured, multi-dimensional tensors through a three-layer abstraction:

1. **Raw Memory**: Linear array of bytes in GPU DRAM
2. **Buffer View**: Adds size, address space, and vectorized access
3. **Tensor View**: Adds shape, strides, and multi-dimensional indexing

This abstraction enables:
- ✅ Clean, readable code
- ✅ Type-safe multi-dimensional access
- ✅ Automatic vectorization
- ✅ Flexible memory-space handling
- ✅ Efficient tile-based computation

The tensor views created here are then passed to the host-level pipeline, which orchestrates the block-level GEMM computation!
150
tutorial/ck_tile/01_naive_gemm/README.md
Normal file
@@ -0,0 +1,150 @@
# CK Tile Practice GEMM Example

This is a practice implementation of a GEMM (General Matrix Multiplication) kernel using the CK Tile API. It demonstrates the fundamental concepts of GPU kernel development using CK Tile's hierarchical tile system.

## CK Tile API Structure

In the composable_kernel library's ck_tile API, **a Kernel is composed of a Problem, a Policy and an Epilogue**:

1. **Problem** describes the shape, data types, data layouts and precision of our GEMM matrices
2. **Policy** describes how the data in a matrix (or tile) is mapped to threads
3. **Epilogue** describes additional computation performed after the GEMM computation (this example does not have an epilogue)

## Overview

This example implements a complete GEMM kernel `C = A × B` using the CK Tile framework, showcasing:

- **Problem Setup** - Setting up the problem (input/output shapes, data types, mathematical operations), composing a kernel (pipeline, policy, epilogue), and launching it
- **Block-level Pipelining** - Creating tensor views and dispatching to the block-level GEMM
- **Block-level GEMM Computation** - Block tiles, tile window creation, loading/storing between DRAM and register memory
- **Warp-level GEMM Computation** - Warp tiles, MFMA-level computation

## Problem Setup and Data Flow

### Problem Size Configuration
We set the problem size using the M, N and K variables:
```cpp
ck_tile::index_t M = 1024; // Number of rows in A and C
ck_tile::index_t N = 512;  // Number of columns in B and C
ck_tile::index_t K = 256;  // Number of columns in A, rows in B
```

### Host Matrix Creation
Three host matrices A (M×K), B (N×K) and C (M×N) are created, initialized on the CPU and copied over to GPU global (DRAM) memory:
```cpp
// Host tensors with proper strides
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides); // M × K
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides); // N × K
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides); // M × N

// Initialize with random data
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);

// Allocate device memory and transfer data
ck_tile::DeviceMem a_device(a_host);
a_device.ToDevice(a_host.data());
```

### PracticeGemmShape Configuration
A `PracticeGemmShape` struct holds the dimensions of each BlockTile and WaveTile:

```cpp
using BlockTile = ck_tile::sequence<256, 128, 32>; // M, N, K per block
using WaveTile  = ck_tile::sequence<16, 16, 16>;   // M, N, K per wave
```
- A BlockTile covers an M×K (256×32) region of the A matrix and an N×K (128×32) region of the B matrix; a WaveTile covers an M×N (16×16) region of the C matrix.
- BlockTiles iterate along the K dimension to fetch the data required for computing the region of C covered by C's block tile.
- BlockTiles are further subdivided into WaveTiles.
- WaveTiles over A and B work together to calculate a WaveTile of C.

### Problem and Policy Composition
```cpp
// A Problem is composed from the Shape and info about the data
using PracticeGemmHostProblem = ck_tile::
    PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;

// A Policy is created describing the data-to-thread mapping
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;

// A Kernel is then composed of Problem and Policy
using gemm_kernel = ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
```

### Kernel Launch
`ck_tile::launch_kernel()` is used to launch the kernel on the device. It calls the `operator()` function of `PracticeGemmKernel{}`:
```cpp
float ave_time = ck_tile::launch_kernel(
    ck_tile::stream_config{nullptr, true, 0, 0, 1},
    ck_tile::make_kernel<kBlockSize, kBlockPerCU>(
        gemm_kernel{}, // Kernel composed of Problem + Policy
        kGridSize,     // Grid dimensions
        kBlockSize,    // Block dimensions
        0,             // Dynamic shared memory
        // Kernel arguments: device buffers and problem dimensions
        a_device.GetDeviceBuffer(), b_device.GetDeviceBuffer(), c_device.GetDeviceBuffer(),
        M, N, K, stride_a, stride_b, stride_c));
```

### Result Verification
The results from the kernel are compared against a CPU-based reference computation:
```cpp
// CPU reference implementation
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(a_host, b_host, c_host_ref);

// Device results
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);

// Verify correctness
bool pass = ck_tile::check_err(c_host_dev, c_host_ref);
```

### Runtime Flow

The main program (`practice_gemm.cpp`) is the entry point for the runtime flow:

```cpp
int main()
{
    // 1. Define data types and problem sizes
    using ADataType = ck_tile::half_t;
    ck_tile::index_t M = 1024, N = 512, K = 256;

    // 2. Create host tensors and initialize
    ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
    ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);

    // 3. Allocate device memory and transfer data
    ck_tile::DeviceMem a_device(a_host);

    // 4. Configure tile shapes
    using BlockTile = ck_tile::sequence<256, 128, 32>;
    using WaveTile  = ck_tile::sequence<16, 16, 16>;

    // 5. Launch kernel
    using gemm_kernel = ck_tile::PracticeGemmKernel<Problem, Policy>;
    float ave_time = ck_tile::launch_kernel(/*...*/);

    // 6. Verify results
    bool pass = verify_results(a_host, b_host, c_host);

    // 7. Print performance metrics
    print_performance_metrics(ave_time, M, N, K);
}
```

## Building and Running

```bash
# From the composable_kernel root directory
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_practice_gemm -j

# Run with sample sizes
./bin/tile_example_practice_gemm
```
This example serves as a foundation for understanding more complex GEMM implementations and optimization strategies in the CK Tile framework.
506
tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md
Normal file
@@ -0,0 +1,506 @@
# Practice GEMM: Step-by-Step Code Walkthrough

This document provides a detailed walkthrough of `practice_gemm.cpp`, explaining each step of implementing a GEMM (General Matrix Multiplication) kernel using the CK Tile API.

## Overview

We'll implement `C = A × B` where:
- `A` is an `M × K` matrix
- `B` is an `N × K` matrix (note: transposed layout)
- `C` is an `M × N` matrix

The implementation uses a hierarchical tiling strategy with two levels:
1. **Block Tiles**: Processed by thread blocks
2. **Wave Tiles**: Processed by warps (wavefronts) within blocks

---

## Step 1: Define Data Types

```cpp
using ADataType   = ck_tile::half_t;
using BDataType   = ck_tile::half_t;
using CDataType   = float;
using AccDataType = float;
```

**What's happening:**
- We use `half_t` (FP16) for the input matrices A and B
- We use `float` (FP32) for the output matrix C and for accumulation, for numerical accuracy
- In typical CK examples, this information is part of a `GemmConfig` struct, but here we define it directly for simplicity
---

## Step 2: Define Problem Size

```cpp
ck_tile::index_t M = 512;
ck_tile::index_t N = 256;
ck_tile::index_t K = 64;
ck_tile::index_t verification = 1;

ck_tile::index_t stride_a = K;
ck_tile::index_t stride_b = K;
ck_tile::index_t stride_c = N;
```

**What's happening:**
- `M = 512`: Number of rows in A and C
- `N = 256`: Number of columns in B and C
- `K = 64`: Inner dimension (columns of A, rows of B)
- Strides define the memory layout (row-major for A and C, transposed for B)

**Memory Layout:**
```
Matrix A (M×K):     Matrix B (N×K):     Matrix C (M×N):
[512 rows]          [256 rows]          [512 rows]
[64 cols]           [64 cols]           [256 cols]
stride = K          stride = K          stride = N
```

---

## Step 3: Create Host Tensors

```cpp
auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};

auto a_strides = std::array<ck_tile::index_t, 2>{stride_a, 1};
auto b_strides = std::array<ck_tile::index_t, 2>{stride_b, 1};
auto c_strides = std::array<ck_tile::index_t, 2>{stride_c, 1};

ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides);
```

**What's happening:**
- We create three tensors in host (CPU) memory
- Each tensor is defined by its shape (`lengths`) and memory layout (`strides`)
- `HostTensor` is a CK Tile utility class that manages CPU memory

**Stride explanation:**
- For A: `stride_a = K` means moving to the next row requires skipping K elements
- For B: `stride_b = K` means B is stored in transposed format
- For C: `stride_c = N` means row-major layout
---

## Step 4: Initialize Tensors with Random Data

```cpp
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
c_host.SetZero();
```

**What's happening:**
- A and B are filled with random values in the range [-5.0, 5.0]
- C is initialized to zero (it will store the output)

**Optional: Print Tensor Contents**
```cpp
// Commented out in the code, but available for debugging:
// a_host.print_first_n(10); // Print first 10 elements of A
```

The `print_first_n()` helper function can display tensor contents for debugging purposes.

---

## Step 5: Allocate Device Memory and Transfer Data

```cpp
ck_tile::DeviceMem a_device(a_host);
ck_tile::DeviceMem b_device(b_host);
ck_tile::DeviceMem c_device(c_host);
```

**What's happening:**
- `DeviceMem` allocates GPU memory matching the size of the host tensors
- The constructor **automatically transfers data from host to device**
- This is a convenience wrapper around `hipMalloc` and `hipMemcpy`

**Memory Flow:**
```
CPU (Host)               GPU (Device)
┌─────────┐              ┌─────────┐
│ a_host  │ ───────────> │a_device │
│ b_host  │ ───────────> │b_device │
│ c_host  │ ───────────> │c_device │
└─────────┘              └─────────┘
```

---

## Step 6: Configure Hierarchical Tiling

```cpp
using BlockTile = ck_tile::sequence<256, 128, 32>;
using WaveTile  = ck_tile::sequence<16, 16, 16>;
```

**What's happening:**
- We define a two-level tiling hierarchy for the GEMM computation

### Block Tile (256 × 128 × 32)
- **256**: M dimension per block (rows of A and C)
- **128**: N dimension per block (columns of B and C)
- **32**: K dimension per block (inner dimension)
- Each block tile is processed by one **thread block** (256 threads)

### Wave Tile (16 × 16 × 16)
- **16 × 16**: Output tile dimensions (M × N) per warp iteration
- **16**: K dimension per warp iteration
- Each wave tile is processed by one **warp** (64 threads on AMD GPUs)

**Important:** The WaveTile (16×16×16) is NOT the same as the MFMA instruction size (32×32×8). The WaveTile represents the work done per warp per iteration, while MFMA is the underlying hardware instruction. Multiple MFMA operations may be needed to compute one wave tile.

**Important Note:**
Each block tile covers a 256 × 128 region of C, so the 512 × 256 problem requires four thread blocks (2 in M × 2 in N), and each block iterates twice over the K = 64 dimension (64 / 32 = 2).
### Tiling Visualization:

#### Matrix A block tile (M × K = 256 × 32):
```
┌─────────────────────────────────────┐
│ One Block Tile (256 × 32)           │
│  ┌────┬────┐                        │
│  │16× │16× │  ← Wave tiles (16×16)  │
│  │ 16 │ 16 │    in M×K space        │
│  ├────┼────┤                        │
│  │    │    │                        │
│  ├────┼────┤                        │
│  │ .. │ .. │  16 tiles in M         │
│  ├────┼────┤   2 tiles in K         │
│  │    │    │                        │
│  └────┴────┘                        │
│                                     │
└─────────────────────────────────────┘
```

#### Matrix B block tile (N × K = 128 × 32):
```
┌──────────────────────────────┐
│ One Block Tile (128 × 32)    │
│  ┌────┬────┐                 │
│  │16× │16× │  ← Wave tiles   │
│  │ 16 │ 16 │    (16×16)      │
│  ├────┼────┤                 │
│  │    │    │                 │
│  ├────┼────┤   8 tiles in N  │
│  │ .. │ .. │   2 tiles in K  │
│  ├────┼────┤                 │
│  │    │    │                 │
│  └────┴────┘                 │
└──────────────────────────────┘
```

#### Matrix C block tile (M × N = 256 × 128) - Output:
```
┌─────────────────────────────────────────────────┐
│ One Block Tile (256 × 128)                      │
│                                                 │
│  ┌────┬────┬────┬────┬────┬────┬────┬────┐      │
│  │16× │    │    │    │    │    │    │    │      │
│  │ 16 │    │    │    │    │    │    │    │      │
│  ├────┼────┼────┼────┼────┼────┼────┼────┤      │
│  │    │    │    │    │    │    │    │    │      │
│  ├────┼────┼────┼────┼────┼────┼────┼────┤      │
│  │    │    │    │    │    │    │    │    │      │
│  ├────┼────┼────┼────┼────┼────┼────┼────┤      │
│  │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ .. │      │
│  ├────┼────┼────┼────┼────┼────┼────┼────┤      │
│  │    │    │    │    │    │    │    │    │      │
│  └────┴────┴────┴────┴────┴────┴────┴────┘      │
│                                                 │
│  16 wave tiles in M direction                   │
│  8 wave tiles in N direction                    │
│  Total: 128 wave tiles (16×16 each)             │
└─────────────────────────────────────────────────┘
```
#### How Wave Tiles Combine (C = A × B):
```
Matrix A          Matrix B (stored transposed N×K)      Matrix C
(256×32)          (128×32)                              (256×128)

Row of A tiles:       Row of B tiles:       One wave tile in C:
┌────┬────┐           ┌────┬────┐           ┌────┐
│ A₀ │ A₁ │     ×     │ B₀ │ B₁ │     =     │ C  │  (16×16)
└────┴────┘           └────┴────┘           └────┘
16×16 each            16×16 each

Computation: C = A₀×B₀ᵀ + A₁×B₁ᵀ
                    ↑        ↑
                 K=0..15  K=16..31

Each wave tile in C is computed by:
- Taking one row of wave tiles from A (2 tiles along K)
- Taking one row of wave tiles from B (2 tiles along K)
  Note: B is stored transposed (N×K), so a "row" in storage corresponds
  to a "column" in the logical Bᵀ matrix used in the computation
- Performing the dot product: Σ(A_k × B_kᵀ) for k=0,1
```

**Key Insight:**
- Each **wave tile in C** (16×16) requires a **dot product** of 2 wave tiles from A and 2 wave tiles from B
- Since B is stored transposed (N×K layout), we access **rows** of B tiles in memory
- This is the fundamental operation repeated across all 128 wave tiles in C
- Each warp computes one wave tile using MFMA instructions
---

## Step 7: Create Shape, Problem, and Policy Structs

```cpp
using PracticeGemmShape = ck_tile::PracticeGemmShape<BlockTile, WaveTile>;
std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl;

using PracticeGemmHostProblem = ck_tile::
    PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;

using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
```

**What's happening:**

### 1. **Shape Struct**
Encapsulates all tile shape information (BlockTile and WaveTile dimensions).

### 2. **Problem Struct**
Holds the complete problem description:
- Data types (ADataType, BDataType, CDataType, AccDataType)
- Shape information (BlockTile, WaveTile)

In more complex examples, this would also include:
- Data layouts (row-major, column-major)
- Mathematical operations (e.g., transposed GEMM)

### 3. **Policy Struct**
Describes data movement and thread-to-data mapping:
- Currently contains `MakeBlock2TileMap()`: Maps thread block IDs to tile positions
- In more complex kernels, it includes:
  - DRAM access patterns
  - LDS (Local Data Share) usage strategies
  - Thread distribution within blocks

**CK Tile Design Pattern:**
```
Kernel = Problem + Policy + Epilogue
           ↑         ↑          ↑
        (What)     (How)   (Post-processing)
```
---

## Step 8: Calculate Grid and Block Dimensions

```cpp
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
                             ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);

std::cout << "kGridSize: " << kGridSize << std::endl;

constexpr ck_tile::index_t kBlockSize  = 256;
constexpr ck_tile::index_t kBlockPerCU = 1;
```

**What's happening:**

### Grid Size Calculation
```cpp
kGridSize = ceil(M / BlockTile_M) × ceil(N / BlockTile_N)
          = ceil(512 / 256) × ceil(256 / 128)
          = 2 × 2
          = 4 thread blocks
```

Our problem requires **4 thread blocks** to cover the entire output matrix C (2 blocks in the M direction, 2 blocks in the N direction).
### Block Configuration
- `kBlockSize = 256`: Each thread block has 256 threads
- 256 threads / 64 threads per warp = **4 warps per block**
- `kBlockPerCU = 1`: Launch 1 block per Compute Unit (for simplicity)

**Thread Hierarchy:**
```
GPU
└── 4 Thread Blocks (Grid), each with:
    └── 256 Threads
        ├── Warp 0 (threads 0-63)
        ├── Warp 1 (threads 64-127)
        ├── Warp 2 (threads 128-191)
        └── Warp 3 (threads 192-255)
```
---

## Step 9: Create and Launch the Kernel

```cpp
using gemm_kernel =
    ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;

float ave_time = ck_tile::launch_kernel(
    ck_tile::stream_config{nullptr, true, 0, 0, 1},
    ck_tile::make_kernel<kBlockPerCU>(gemm_kernel{},
                                      kGridSize,
                                      kBlockSize,
                                      0,
                                      static_cast<ADataType*>(a_device.GetDeviceBuffer()),
                                      static_cast<BDataType*>(b_device.GetDeviceBuffer()),
                                      static_cast<CDataType*>(c_device.GetDeviceBuffer()),
                                      M,
                                      N,
                                      K,
                                      stride_a,
                                      stride_b,
                                      stride_c));
```

**What's happening:**

### 1. Kernel Composition
```cpp
using gemm_kernel = ck_tile::PracticeGemmKernel<Problem, Policy>;
```
The kernel is composed from the Problem and Policy structs, following the CK Tile design pattern.

### 2. Kernel Launch
`launch_kernel()` is a CK Tile utility that:
- Launches the GPU kernel using the HIP runtime
- Measures execution time
- Returns the average execution time in milliseconds

### 3. Launch Parameters
- **Stream config**: `{nullptr, true, 0, 0, 1}` - default stream, timing enabled
- **Grid size**: `kGridSize = 4` - number of thread blocks
- **Block size**: `kBlockSize = 256` - threads per block
- **Shared memory**: `0` - no dynamic shared memory in this example
- **Kernel arguments**: Device pointers and problem dimensions

### 4. Kernel Execution Flow
```
launch_kernel() calls gemm_kernel.operator()()
    ↓
PracticeGemmKernel::operator()
    ↓
Creates tensor views over device memory
    ↓
Calls block-level pipeline
    ↓
Block pipeline calls warp-level pipeline
    ↓
Warp pipeline calls MFMA instructions
    ↓
Results written back to C matrix
```
---

## Step 10: Verify Results

```cpp
auto pass = true;

if(verification)
{
    // Reference gemm on CPU
    ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
    reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
        a_host, b_host, c_host_ref);

    // Copy GPU results back to host
    ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
    c_device.FromDevice(c_host_dev.mData.data());

    // Compare results
    pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
    std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
```

**What's happening:**

### 1. CPU Reference Implementation
```cpp
reference_basic_gemm<...>(a_host, b_host, c_host_ref);
```
Computes the GEMM on the CPU using a simple nested-loop implementation (the ground truth).

### 2. Copy GPU Results to Host
```cpp
c_device.FromDevice(c_host_dev.mData.data());
```
Transfers the computed result from GPU memory back to the CPU for comparison.

### 3. Error Checking
```cpp
ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
```
Compares GPU and CPU results element-wise with tolerance:
- **Relative error**: 1e-3 (0.1%)
- **Absolute error**: 1e-3
**Verification Flow:**
```
CPU                      GPU
┌─────────┐              ┌─────────┐
│ a_host  │ ───────────> │a_device │
│ b_host  │ ───────────> │b_device │
└─────────┘              └─────────┘
     │                        │
     ↓                        ↓
reference_gemm()          GPU kernel
     │                        │
     ↓                        ↓
┌──────────┐             ┌──────────┐
│c_host_ref│             │c_device  │
└──────────┘             └──────────┘
     │                        │
     │                        ↓
     │                  FromDevice()
     │                        │
     ↓                        ↓
     └────> check_err() <─────┘
                 │
                 ↓
             Pass/Fail
```
---

## Complete Execution Flow Summary

```
1. Define data types (FP16 inputs, FP32 output)
   ↓
2. Set problem size (M=512, N=256, K=64)
   ↓
3. Create host tensors and initialize with random data
   ↓
4. Allocate device memory and transfer data (CPU → GPU)
   ↓
5. Configure hierarchical tiling (BlockTile, WaveTile)
   ↓
6. Create Shape, Problem, and Policy structs
   ↓
7. Calculate grid/block dimensions (4 blocks, 256 threads each)
   ↓
8. Compose and launch kernel (Problem + Policy)
   ↓
9. Execute GEMM on GPU
   ├─ Block-level pipeline
   ├─ Warp-level pipeline
   └─ MFMA instructions
   ↓
10. Verify results (compare GPU vs CPU reference)
   ↓
11. Calculate and print performance metrics
   ↓
12. Return success/failure
```

---
@@ -0,0 +1,165 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"

namespace ck_tile {

template <typename Problem, typename Policy = PracticeGemmBlockPolicy>
struct PracticeGemmBlockPipelineAGmemBGmemCreg
{
    using ADataType   = typename Problem::ADataType;
    using BDataType   = typename Problem::BDataType;
    using CDataType   = typename Problem::CDataType;
    using AccDataType = typename Problem::AccDataType;

    using BlockTile = typename Problem::Shape::BlockTile;
    using WaveTile  = typename Problem::Shape::WaveTile;

    static constexpr index_t MPerBlock = BlockTile::at(number<0>{});
    static constexpr index_t NPerBlock = BlockTile::at(number<1>{});
    static constexpr index_t KPerBlock = BlockTile::at(number<2>{});

    static constexpr index_t MPerWave = WaveTile::at(number<0>{});
    static constexpr index_t NPerWave = WaveTile::at(number<1>{});
    static constexpr index_t KPerWave = WaveTile::at(number<2>{});

    using BlockGemm =
        remove_cvref_t<decltype(Policy::template GetPracticeWaveGemmPipeline<Problem>())>;

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLDSSize()
    {
        return integer_divide_ceil(
                   sizeof(ADataType) *
                       Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
                   16) *
                   16 +
               sizeof(BDataType) *
                   Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
    }

    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                        const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                        index_t num_loop,
                                        void* p_smem) const
    {
        static_assert(
            std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        // -----------------------------------------------------------------------------------------
        // Definitions of all needed tiles

        // A tile in LDS
        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);

        constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();

        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        constexpr index_t a_lds_block_space_size_aligned =
            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
            16;

        // B tile in LDS
        BDataType* p_b_lds = static_cast<BDataType*>(
            static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));

        constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();

        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

        // A DRAM tile window for load
        auto a_copy_dram_window =
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeADramTileDistribution<Problem>());

        // A LDS tile window for store
        auto a_copy_lds_window =
            make_tile_window(a_lds_block,
                             make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                             {0, 0},
                             a_copy_dram_window.get_tile_distribution());

        // B DRAM tile window for load
        auto b_copy_dram_window =
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeBDramTileDistribution<Problem>());

        // B LDS tile window for store
        auto b_copy_lds_window =
            make_tile_window(b_lds_block,
                             make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                             {0, 0},
                             b_copy_dram_window.get_tile_distribution());

        // A LDS tile for block GEMM
        auto a_lds_gemm_window = make_tile_window(
            a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});

        // B LDS tile for block GEMM
        auto b_lds_gemm_window = make_tile_window(
            b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

        // Block GEMM
        auto block_gemm = BlockGemm();

        // Acc register tile
        auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};

        using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
        using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

        using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
        using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

        ABlockTile a_block_tile;
        BBlockTile b_block_tile;

        using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
        using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;

        constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
        constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock);

        // -------------------------------------------------------------------------------------
        // Gemm pipeline start

        // Initialize C
        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

        // non-prefetch
        index_t iCounter = num_loop;

        while(iCounter > 0)
        {
            a_block_tile = load_tile(a_copy_dram_window); // from DRAM to registers
            b_block_tile = load_tile(b_copy_dram_window); // from DRAM to registers
            move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
            move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
            store_tile(a_copy_lds_window, a_block_tile); // from registers to LDS
            store_tile(b_copy_lds_window, b_block_tile); // from registers to LDS

            block_sync_lds();
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // from LDS to registers
            block_sync_lds();

            iCounter--;
        }

        return c_block_tile;
    }
};

} // namespace ck_tile
@@ -0,0 +1,135 @@
#pragma once

#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"

#include "../warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp"
#include "../warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp"

namespace ck_tile {

template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename AccDataType_,
          typename Shape_>
struct PracticeGemmBlockPipelineProblem
{
    using ADataType   = ADataType_;
    using BDataType   = BDataType_;
    using CDataType   = CDataType_;
    using AccDataType = AccDataType_;
    using Shape       = Shape_;
};

struct PracticeGemmBlockPolicy
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetPracticeWaveGemmPipeline()
    {
        return PracticeGemmWarpPipelineASmemBSmemCreg<Problem>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
    {
        constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{});
        constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
        constexpr index_t kKPack     = 8;

        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
            make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
            number<kKPack>{},
            number<1>{});

        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
            a_lds_block_desc_0,
            make_tuple(make_pass_through_transform(kMPerBlock),
                       make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
            make_tuple(sequence<0>{}, sequence<1, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return a_lds_block_desc;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
    {
        constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{});
        constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
        constexpr index_t kKPack     = 8;

        constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
            make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
            number<kKPack>{},
            number<1>{});

        constexpr auto b_lds_block_desc = transform_tensor_descriptor(
            b_lds_block_desc_0,
            make_tuple(make_pass_through_transform(kNPerBlock),
                       make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
            make_tuple(sequence<0>{}, sequence<1, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return b_lds_block_desc;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        using BlockGemm = remove_cvref_t<decltype(GetPracticeWaveGemmPipeline<Problem>())>;

        constexpr index_t kMWarp     = BlockGemm::MWarp;
        constexpr index_t kNWarp     = BlockGemm::NWarp;
        constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size();

        constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{});
        constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});

        constexpr index_t K1 = 16 / sizeof(ADataType);
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t M2 = get_warp_size() / K0;
        // coalesced reads for each block
        constexpr index_t M1 = kBlockSize / get_warp_size();
        constexpr index_t M0 = kMPerBlock / (M2 * M1);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
        using BlockGemm = remove_cvref_t<decltype(GetPracticeWaveGemmPipeline<Problem>())>;

        constexpr index_t kMWarp     = BlockGemm::MWarp;
        constexpr index_t kNWarp     = BlockGemm::NWarp;
        constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size();

        constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{});
        constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});

        constexpr index_t K1 = 16 / sizeof(BDataType);
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;
        // coalesced reads for each block
        constexpr index_t N1 = kBlockSize / get_warp_size();
        constexpr index_t N0 = kNPerBlock / (N2 * N1);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }
};

} // namespace ck_tile
@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"

namespace ck_tile {

template <typename Problem_, typename Policy_ = PracticeGemmHostPolicy>
struct PracticeGemmHostPipeline
{
    using ADataType   = typename Problem_::ADataType;
    using BDataType   = typename Problem_::BDataType;
    using CDataType   = typename Problem_::CDataType;
    using AccDataType = typename Problem_::AccDataType;

    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    using BlockTile = typename Problem::Shape::BlockTile;
    using WaveTile  = typename Problem::Shape::WaveTile;

    template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
    CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
                                   const BDRAMTensorView& b_dram,
                                   CDRAMTensorView& c_dram) const
    {
        // Size of the entire problem
        const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // A is M x K
        const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // C is M x N
        const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // A is M x K

        // Size of the block tile
        const auto MPerBlock = BlockTile::at(number<0>{});
        const auto NPerBlock = BlockTile::at(number<1>{});
        const auto KPerBlock = BlockTile::at(number<2>{});

        // Number of block tiles in the N direction needed to cover the C (result) matrix
        const auto num_tile_n = integer_divide_ceil(N, NPerBlock);
        // Number of block tiles in the M direction needed to cover the C (result) matrix
        const auto num_tile_m = integer_divide_ceil(M, MPerBlock);

        // if(get_thread_id() == 0 && get_block_id() == 0)
        // {
        //     printf("num_tile_m: %d, num_tile_n: %d\n", num_tile_m, num_tile_n);
        //     printf("total number of tiles: %d\n", num_tile_m * num_tile_n);
        // }

        // Get block id, in 0 .. (M/BlockTile_M) * (N/BlockTile_N) - 1
        const auto id_block = get_block_id();

        // Map block id to tile id
        const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n);

        const auto tile_id = block2tile(id_block);

        const auto tile_id_m = tile_id.at(number<0>{});
        const auto tile_id_n = tile_id.at(number<1>{});

        // if(get_thread_id() == 0 && get_block_id() == 15)
        // {
        //     printf("tile_id_m: %d, tile_id_n: %d\n", tile_id_m, tile_id_n);
        // }

        const auto tile_origin_m = tile_id_m * MPerBlock;
        const auto tile_origin_n = tile_id_n * NPerBlock;

        // Create a tile window over DRAM for A and B
        const auto a_block_window = make_tile_window(
            a_dram, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {tile_origin_m, 0});

        const auto b_block_window = make_tile_window(
            b_dram, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {tile_origin_n, 0});

        constexpr auto block_gemm_pipeline =
            Policy::template GetPracticeGemmBlockPipeline<Problem>();

        const int num_loops_k = integer_divide_ceil(K, KPerBlock);

        __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()];

        const auto c_block_tile =
            block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char);

        auto c_window = make_tile_window(c_dram,
                                         make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
                                         {tile_origin_m, tile_origin_n});
        store_tile(c_window, c_block_tile);
    }
};
} // namespace ck_tile
@@ -0,0 +1,51 @@
#pragma once

#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"

#include "../block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp"
#include "../block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp"

namespace ck_tile {

template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename AccDataType_,
          typename Shape_>
struct PracticeGemmHostProblem
{
    using ADataType   = ADataType_;
    using BDataType   = BDataType_;
    using CDataType   = CDataType_;
    using AccDataType = AccDataType_;
    using Shape       = remove_cvref_t<Shape_>;
};

struct PracticeGemmHostPolicy
{
    CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
    {
        const auto unmerge = make_merge_transform(make_tuple(N0, M0));

        return [unmerge](index_t block_id) {
            multi_index<2> unmerged;
            unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));

            return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
        };
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetPracticeGemmBlockPipeline()
    {
        using PracticeGemmBlockPipelineProblem_ =
            PracticeGemmBlockPipelineProblem<typename Problem::ADataType,
                                             typename Problem::BDataType,
                                             typename Problem::CDataType,
                                             typename Problem::AccDataType,
                                             typename Problem::Shape>;

        return PracticeGemmBlockPipelineAGmemBGmemCreg<PracticeGemmBlockPipelineProblem_>{};
    }
};
} // namespace ck_tile
131  tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp  Normal file
@@ -0,0 +1,131 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>

#include "ck_tile/host.hpp"
#include "practice_gemm.hpp"
#include "reference_gemm.hpp"

int main()
{
    // TODO: GemmTypeConfig
    using ADataType   = ck_tile::half_t;
    using BDataType   = ck_tile::half_t;
    using CDataType   = float;
    using AccDataType = float;

    // TODO: replace with an ArgParser
    ck_tile::index_t M            = 512;
    ck_tile::index_t N            = 256;
    ck_tile::index_t K            = 64;
    ck_tile::index_t verification = 1;

    // A is M x K row-major, B is stored N x K, so both are contiguous along K
    ck_tile::index_t stride_a = K;
    ck_tile::index_t stride_b = K;
    ck_tile::index_t stride_c = N;

    auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
    auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
    auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};

    auto a_strides = std::array<ck_tile::index_t, 2>{stride_a, 1};
    auto b_strides = std::array<ck_tile::index_t, 2>{stride_b, 1};
    auto c_strides = std::array<ck_tile::index_t, 2>{stride_c, 1};

    // Tensors on host (CPU)
    ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
    ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
    ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides);

    // Initialize tensors
    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
    ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
    c_host.SetZero();

    // Print the tensors using the print_first_n member function
    // std::cout << "Tensor A (first 10 elements): ";
    // a_host.print_first_n(10);
    // std::cout << std::endl;

    // std::cout << "Tensor B (first 10 elements): ";
    // b_host.print_first_n(10);
    // std::cout << std::endl;

    // std::cout << "Tensor C (first 10 elements): ";
    // c_host.print_first_n(10);
    // std::cout << std::endl;

    // Create device tensors of the same size as the host tensors and copy data over
    ck_tile::DeviceMem a_device(a_host);
    ck_tile::DeviceMem b_device(b_host);
    ck_tile::DeviceMem c_device(c_host);

    // TODO: BlockTileConfig
    // constexpr ck_tile::index_t warpSize = 64;
    constexpr ck_tile::index_t kBlockSize = 256;

    using BlockTile = ck_tile::sequence<256, 128, 32>;
    using WaveTile  = ck_tile::sequence<16, 16, 16>;

    std::cout << "Creating PracticeGemmShape, PracticeGemmProblem, PracticeGemmPolicy" << std::endl;
    using PracticeGemmShape = ck_tile::PracticeGemmShape<BlockTile, WaveTile>;
    std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl;

    using PracticeGemmHostProblem = ck_tile::
        PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
    using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;

    // One workgroup per block tile of C
    ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
                                 ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);

    std::cout << "kGridSize: " << kGridSize << std::endl;
    constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU

    std::cout << "kBlockSize: " << kBlockSize << std::endl;
    std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl;

    using gemm_kernel =
        ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;

    float ave_time = ck_tile::launch_kernel(
        ck_tile::stream_config{nullptr, true, 0, 0, 1},
        ck_tile::make_kernel<kBlockPerCU>(gemm_kernel{},
                                          kGridSize,
                                          kBlockSize,
                                          0,
                                          static_cast<ADataType*>(a_device.GetDeviceBuffer()),
                                          static_cast<BDataType*>(b_device.GetDeviceBuffer()),
                                          static_cast<CDataType*>(c_device.GetDeviceBuffer()),
                                          M,
                                          N,
                                          K,
                                          stride_a,
                                          stride_b,
                                          stride_c));

    auto pass = true;

    if(verification)
    {
        // Reference GEMM on the host
        ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
        reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_host, b_host, c_host_ref);

        ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
        c_device.FromDevice(c_host_dev.mData.data());
        pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
        std::cout << "valid: " << (pass ? "y" : "n") << std::endl;
    }

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

    // ave_time is in ms, so flop / 1e9 / ms == TFLOP/s and bytes / 1e6 / ms == GB/s
    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    return !pass;
}
69  tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp  Normal file
@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

#include "ck_tile/core.hpp"
#include "host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp"
#include "host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp"

namespace ck_tile {

template <typename BlockTile_, typename WaveTile_>
struct PracticeGemmShape
{
    using BlockTile = remove_cvref_t<BlockTile_>;
    using WaveTile  = remove_cvref_t<WaveTile_>;

    static constexpr index_t BlockTile_M = BlockTile::at(number<0>{});
    static constexpr index_t BlockTile_N = BlockTile::at(number<1>{});
    static constexpr index_t BlockTile_K = BlockTile::at(number<2>{});

    static constexpr index_t WaveTile_M = WaveTile::at(number<0>{});
    static constexpr index_t WaveTile_N = WaveTile::at(number<1>{});
    static constexpr index_t WaveTile_K = WaveTile::at(number<2>{});

    CK_TILE_HOST static std::string GetName()
    {
        // clang-format off
        return concat('_', "practice_gemm_shape",
                      concat('x', BlockTile_M, BlockTile_N, BlockTile_K),
                      concat('x', WaveTile_M, WaveTile_N, WaveTile_K));
        // clang-format on
    }
};

template <typename Problem_, typename Policy_>
struct PracticeGemmKernel
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    static constexpr index_t kBlockSize = 256;

    CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
                                   const typename Problem::BDataType* p_b,
                                   typename Problem::CDataType* p_c,
                                   const index_t M,
                                   const index_t N,
                                   const index_t K,
                                   const index_t stride_a,
                                   const index_t stride_b,
                                   const index_t stride_c) const
    {
        auto a_dram = make_naive_tensor_view<address_space_enum::global>(
            p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{});

        auto b_dram = make_naive_tensor_view<address_space_enum::global>(
            p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{});

        // non-const: the pipeline stores the C block tile through this view
        auto c_dram = make_naive_tensor_view<address_space_enum::global>(
            p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{});

        PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram);
    }
};

} // namespace ck_tile
36  tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp  Normal file
@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

// Host reference: C[m][n] = sum_k A[m][k] * B[n][k]  (B is stored N x K)
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_basic_gemm(const ck_tile::HostTensor<ADataType>& a_m_k,
                          const ck_tile::HostTensor<BDataType>& b_n_k,
                          ck_tile::HostTensor<CDataType>& c_m_n)
{
    const int N = b_n_k.mDesc.get_lengths()[0];
    const int K = b_n_k.mDesc.get_lengths()[1];

    auto f = [&](auto m) {
        for(int n = 0; n < N; ++n)
        {
            AccDataType v_acc = 0;

            for(int k = 0; k < K; ++k)
            {
                ADataType v_a = a_m_k(m, k);
                BDataType v_b = b_n_k(n, k);

                v_acc += ck_tile::type_convert<AccDataType>(v_a) *
                         ck_tile::type_convert<AccDataType>(v_b);
            }

            c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_acc);
        }
    };

    ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(1);
}
@@ -0,0 +1,195 @@
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"

namespace ck_tile {

template <typename Problem, typename Policy = PracticeGemmWarpPolicy>
struct PracticeGemmWarpPipelineASmemBSmemCreg
{
    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using CDataType = remove_cvref_t<typename Problem::CDataType>;

    using WaveGemmShape = remove_cvref_t<typename Problem::Shape>;

    using WarpGemm = remove_cvref_t<
        decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;

    static constexpr index_t MWarp =
        Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
    static constexpr index_t NWarp =
        Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();

    using AWarpDstr = typename WarpGemm::AWarpDstr;
    using BWarpDstr = typename WarpGemm::BWarpDstr;
    using CWarpDstr = typename WarpGemm::CWarpDstr;

    using AWarpTensor = typename WarpGemm::AWarpTensor;
    using BWarpTensor = typename WarpGemm::BWarpTensor;
    using CWarpTensor = typename WarpGemm::CWarpTensor;

    static constexpr auto a_warp_y_lengths =
        to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
    static constexpr auto b_warp_y_lengths =
        to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
    static constexpr auto c_warp_y_lengths =
        to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

    static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
    static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
    static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

    // C += A * B
    template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
                                   [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
    {
        static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
                          std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
                          std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                      "wrong!");

        constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];

        static_assert(MPerBlock == WaveGemmShape::BlockTile_M &&
                          NPerBlock == WaveGemmShape::BlockTile_N &&
                          KPerBlock == WaveGemmShape::BlockTile_K,
                      "wrong!");

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

#if !defined(ENABLE_PREFETCH)
        constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        const index_t iMWarp = get_warp_id() / NWarp;
        const index_t iNWarp = get_warp_id() % NWarp;

        // Construct A-warp-window
        auto a_warp_window_tmp = make_tile_window(
            a_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
            {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
             a_block_window_tmp.get_window_origin().at(number<1>{})},
            make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));

        statically_indexed_array<
            statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
            MIterPerWarp>
            a_warp_windows;

        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });

        // Construct B-warp-window
        auto b_warp_window_tmp = make_tile_window(
            b_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
            {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
             b_block_window_tmp.get_window_origin().at(number<1>{})},
            make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));

        statically_indexed_array<
            statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
            NIterPerWarp>
            b_warp_windows;

        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });
#endif

        // hot loop:
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // Read A warp tensor from A block tensor
                AWarpTensor a_warp_tensor;

                a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));

                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // Read B warp tensor from B block tensor
                    BWarpTensor b_warp_tensor;

                    b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));

                    // Read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // Warp GEMM
                    WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                    // Write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    // C = A * B
    template <typename ABlockWindowTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
                                   [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
    {
        static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
                          std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
                      "wrong!");

        constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];

        static_assert(MPerBlock == WaveGemmShape::BlockTile_M &&
                          NPerBlock == WaveGemmShape::BlockTile_N &&
                          KPerBlock == WaveGemmShape::BlockTile_K,
                      "wrong!");

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);

        static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");

        // Construct C-Block-Tensor
        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);

        return c_block_tensor;
    }
};

} // namespace ck_tile
@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"

namespace ck_tile {

// Default policy for BlockGemmASmemBSmemCReg.
// A default policy class should not be templated; put the template on member functions instead.
struct PracticeGemmWarpPolicy
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        constexpr index_t kMWarp = 4;
        constexpr index_t kNWarp = 1;

        if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
                     std::is_same_v<typename Problem::BDataType, half_t> &&
                     std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(
                WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
        }
        else
        {
            // A plain static_assert(false, ...) is ill-formed before C++23; make the
            // condition dependent on Problem so it only fires when this branch is
            // actually instantiated.
            static_assert(sizeof(typename Problem::ADataType) == 0,
                          "Unsupported data type configuration for GEMM warp execution.");
        }
    }
};

} // namespace ck_tile