mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
feat: add new optimized tutorial kernels
- Add 01_naive_gemm baseline implementation - Add 02_padding_k_first with PADDING_K_FIRST + MFMA_32x32x16 - Add 03_mfma_16x16x16 with PADDING_K_FIRST + MFMA_16x16x16 - Share common reference_gemm.hpp in parent gemm/ directory
This commit is contained in:
589
tutorial/ck_tile/gemm/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md
Normal file
589
tutorial/ck_tile/gemm/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md
Normal file
@@ -0,0 +1,589 @@
|
||||
# Block-Level Pipeline: PracticeGemmBlockPipelineAGmemBGmemCreg
|
||||
|
||||
## Overview
|
||||
|
||||
The **Block-Level Pipeline** is where the actual GEMM computation happens for one block tile. It orchestrates:
|
||||
1. **Data movement** from DRAM → Registers → LDS
|
||||
2. **GEMM computation** using data in LDS
|
||||
3. **Iteration** over the K dimension when needed
|
||||
|
||||
This pipeline is called by the host-level pipeline for each block tile that covers a portion of the output matrix C.
|
||||
|
||||
---
|
||||
|
||||
## Architecture: Problem and Policy
|
||||
|
||||
Like other components in CK Tile, the block pipeline follows the **Problem/Policy** pattern:
|
||||
|
||||
### Problem: `PracticeGemmBlockPipelineProblem`
|
||||
Contains:
|
||||
- **Data types**: `ADataType`, `BDataType`, `CDataType`, `AccDataType`
|
||||
- **Shape information**: `BlockTile` and `WaveTile` dimensions
|
||||
|
||||
### Policy: `PracticeGemmBlockPolicy`
|
||||
Contains strategies for:
|
||||
1. **Tile Distribution** (`MakeADramTileDistribution`, `MakeBDramTileDistribution`)
|
||||
- Defines how 256 threads in a block map to elements of a block tile
|
||||
- Each thread knows which elements to load/store from DRAM to its registers
|
||||
- We'll cover tile distribution construction in detail later
|
||||
|
||||
2. **LDS Layout** (`MakeALdsBlockDescriptor`, `MakeBLdsBlockDescriptor`)
|
||||
- Describes how data is logically organized in Local Data Share (LDS)
|
||||
- Optimizes for bank conflict avoidance and efficient access patterns
|
||||
- We'll cover LDS descriptor construction in detail later
|
||||
|
||||
3. **Warp Pipeline** (`GetPracticeWaveGemmPipeline`)
|
||||
- Returns the warp-level GEMM implementation
|
||||
|
||||
---
|
||||
|
||||
## Inputs and Outputs
|
||||
|
||||
```cpp
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
```
|
||||
|
||||
### Inputs:
|
||||
- `a_dram_block_window_tmp`: Tile window over A in DRAM (size: MPerBlock × KPerBlock)
|
||||
- `b_dram_block_window_tmp`: Tile window over B in DRAM (size: NPerBlock × KPerBlock)
|
||||
- `num_loop`: Number of iterations along K dimension
|
||||
- `p_smem`: Pointer to shared memory (LDS)
|
||||
|
||||
### Output:
|
||||
- `c_block_tile`: A `static_distributed_tensor` containing the computed C tile in registers (VGPRs)
|
||||
|
||||
---
|
||||
|
||||
## Step-by-Step Walkthrough
|
||||
|
||||
### Step 1: Create LDS Tensor Views
|
||||
|
||||
```cpp
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// B tile in LDS (placed after A in shared memory)
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We partition the shared memory (`p_smem`) into two regions: one for A, one for B
|
||||
- We create **tensor views** over these LDS regions using descriptors from the policy
|
||||
- `a_lds_block` and `b_lds_block` are logical views over raw LDS memory
|
||||
|
||||
**Memory Layout:**
|
||||
```
|
||||
Shared Memory (LDS):
|
||||
┌─────────────────────┬─────────────────────┐
|
||||
│ A Block Tile │ B Block Tile │
|
||||
│ (256×32 fp16) │ (128×32 fp16) │
|
||||
└─────────────────────┴─────────────────────┘
|
||||
↑ ↑
|
||||
p_a_lds p_b_lds
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 2: Create Tile Windows for Data Movement
|
||||
|
||||
We create **6 tile windows** for different purposes:
|
||||
|
||||
#### 2a. DRAM → Registers (Load from DRAM)
|
||||
|
||||
```cpp
|
||||
auto a_copy_dram_window = make_tile_window(
|
||||
a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>()); // ← Tile distribution!
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- `a_copy_dram_window` is a `tile_window_with_static_distribution`
|
||||
- The **tile distribution** tells each thread which elements to load from DRAM
|
||||
- This window will **slide along the K dimension** in the loop
|
||||
|
||||
#### 2b. Registers → LDS (Store to LDS)
|
||||
|
||||
```cpp
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32
|
||||
{0, 0}, // Origin at (0, 0) in LDS
|
||||
a_copy_dram_window.get_tile_distribution()); // ← Same distribution as DRAM!
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- Uses the **same tile distribution** as `a_copy_dram_window`
|
||||
- This ensures each thread stores to LDS in the same pattern it loaded from DRAM
|
||||
- Origin is always `{0, 0}` because LDS is reused for each K iteration
|
||||
|
||||
#### 2c. LDS → Registers (GEMM Input)
|
||||
|
||||
```cpp
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0}); // No tile distribution!
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- This is a `tile_window_with_static_lengths` (no explicit distribution)
|
||||
- Used as input to the warp-level GEMM
|
||||
- The warp GEMM will handle its own thread mapping internally
|
||||
|
||||
**Similar windows are created for B:**
|
||||
- `b_copy_dram_window`: Load B from DRAM
|
||||
- `b_copy_lds_window`: Store B to LDS
|
||||
- `b_lds_gemm_window`: Read B from LDS for GEMM
|
||||
|
||||
---
|
||||
|
||||
### Step 3: Create Distributed Tensors (VGPRs)
|
||||
|
||||
```cpp
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile; // Per-thread registers for A
|
||||
BBlockTile b_block_tile; // Per-thread registers for B
|
||||
```
|
||||
|
||||
#### What is `make_static_distributed_tensor`?
|
||||
|
||||
**`make_static_distributed_tensor`** creates a **`static_distributed_tensor`**, which is a compile-time abstraction for **distributed per-thread register storage**.
|
||||
|
||||
**Key Properties:**
|
||||
1. **Per-thread VGPRs**: Each thread owns a **different slice** of the tile in its registers
|
||||
2. **Compile-time sized**: Buffer size determined by tile distribution at compile time
|
||||
3. **Zero-overhead**: All indexing and layout transformations happen at compile time
|
||||
|
||||
**How it works:**
|
||||
|
||||
```cpp
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
// Calculate per-thread storage size from tile distribution
|
||||
using ThreadTensorDesc =
|
||||
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
|
||||
|
||||
static constexpr index_t kThreadElementSpaceSize =
|
||||
ThreadTensorDesc{}.get_element_space_size();
|
||||
|
||||
// Per-thread register array (VGPRs)
|
||||
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
|
||||
};
|
||||
```
|
||||
|
||||
**The tile distribution defines:**
|
||||
- **Which elements each thread owns** in the tile
|
||||
- **How many elements** each thread stores (buffer size)
|
||||
- **How elements are laid out** in each thread's registers
|
||||
|
||||
**Concrete Example for 256×32 tile with 256 threads:**
|
||||
|
||||
```
|
||||
Thread 0: a_block_tile.thread_buf_ = [A[0,0], A[0,1], ..., A[0,31]] (32 fp16 values)
|
||||
Thread 1: a_block_tile.thread_buf_ = [A[1,0], A[1,1], ..., A[1,31]] (32 fp16 values)
|
||||
Thread 2: a_block_tile.thread_buf_ = [A[2,0], A[2,1], ..., A[2,31]] (32 fp16 values)
|
||||
...
|
||||
Thread 255: a_block_tile.thread_buf_ = [A[255,0], A[255,1], ..., A[255,31]] (32 fp16 values)
|
||||
```
|
||||
|
||||
**Collectively:**
|
||||
- All 256 threads together hold the **entire 256×32 tile** (8192 elements)
|
||||
- Each thread's buffer lives in its **own VGPRs**
|
||||
- No two threads own the same element
|
||||
|
||||
**Distributed Ownership Analogy:**
|
||||
Think of a tile as a **jigsaw puzzle**:
|
||||
- The **tile distribution** is the cutting pattern
|
||||
- Each **thread** gets one puzzle piece (its slice)
|
||||
- Each **`static_distributed_tensor`** is a box holding all pieces
|
||||
- Each thread's **`thread_buf_`** is its individual piece in its own registers
|
||||
|
||||
---
|
||||
|
||||
### Step 4: The GEMM Loop
|
||||
|
||||
```cpp
|
||||
// Initialize C accumulator to zero
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
index_t iCounter = num_loop; // Number of K iterations
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// 1. Load from DRAM to registers
|
||||
a_block_tile = load_tile(a_copy_dram_window); // DRAM → VGPRs
|
||||
b_block_tile = load_tile(b_copy_dram_window); // DRAM → VGPRs
|
||||
|
||||
// 2. Move windows for next iteration
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step); // Step by (0, 32)
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // Step by (0, 32)
|
||||
|
||||
// 3. Store from registers to LDS
|
||||
store_tile(a_copy_lds_window, a_block_tile); // VGPRs → LDS
|
||||
store_tile(b_copy_lds_window, b_block_tile); // VGPRs → LDS
|
||||
|
||||
// 4. Synchronize threads (ensure all data is in LDS)
|
||||
block_sync_lds();
|
||||
|
||||
// 5. Compute GEMM using data in LDS
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// 6. Synchronize threads (before overwriting LDS in next iteration)
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
return c_block_tile; // Return accumulated result in registers
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Detailed Loop Breakdown
|
||||
|
||||
### Phase 1: Load (DRAM → VGPRs)
|
||||
|
||||
```cpp
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. Each thread reads **its assigned elements** from DRAM (determined by tile distribution)
|
||||
2. Data is loaded into **per-thread registers** (VGPRs)
|
||||
3. Uses **vectorized loads** for efficiency (e.g., loading 8 fp16 values at once)
|
||||
|
||||
**Example for Thread 0:**
|
||||
```
|
||||
Thread 0 loads:
|
||||
A[0,0:7] (8 fp16 values, one vector load)
|
||||
A[1,0:7] (8 fp16 values, one vector load)
|
||||
...
|
||||
```
|
||||
|
||||
### Phase 2: Move Windows
|
||||
|
||||
```cpp
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
- The tile window **slides along the K dimension** by `KPerBlock` (32 in our example)
|
||||
- This prepares for the next K iteration
|
||||
- The window origin moves from `(0, 0)` → `(0, 32)` → `(0, 64)` → ...
|
||||
|
||||
**Visualization for Problem Size 512×256×64:**
|
||||
```
|
||||
Matrix A (512×64):
|
||||
┌─────────────────────────────────────┐
|
||||
│ Block 0: rows 0-255 │
|
||||
│ ┌──────────┬──────────┐ │
|
||||
│ │ K=0:31 │ K=32:63 │ │ ← Window slides right
|
||||
│ │ Iter 0 │ Iter 1 │ │
|
||||
│ └──────────┴──────────┘ │
|
||||
└─────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Phase 3: Store (VGPRs → LDS)
|
||||
|
||||
```cpp
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. Each thread writes **its elements** from registers to LDS
|
||||
2. Uses the **same distribution** as the DRAM load
|
||||
3. Data is now in **shared memory**, accessible to all threads in the block
|
||||
|
||||
**Why this step?**
|
||||
- GEMM computation needs **all threads** to access **all data**
|
||||
- Registers are per-thread; LDS is shared across the block
|
||||
- LDS acts as a "staging area" for collaborative computation
|
||||
|
||||
### Phase 4: Synchronize
|
||||
|
||||
```cpp
|
||||
block_sync_lds();
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
- All threads in the block **wait** until everyone has finished storing to LDS
|
||||
- Ensures no thread starts reading from LDS before all writes are complete
|
||||
- Critical for correctness!
|
||||
|
||||
### Phase 5: GEMM Computation
|
||||
|
||||
```cpp
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. The warp-level GEMM reads data from LDS
|
||||
2. Performs matrix multiplication using MFMA instructions
|
||||
3. Accumulates results into `c_block_tile` (in registers)
|
||||
|
||||
**Note:** `c_block_tile` stays in registers throughout all K iterations, accumulating results.
|
||||
|
||||
### Phase 6: Synchronize Again
|
||||
|
||||
```cpp
|
||||
block_sync_lds();
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
- Ensures all threads have finished reading from LDS
|
||||
- Safe to overwrite LDS in the next iteration
|
||||
|
||||
---
|
||||
|
||||
## Memory Flow Diagram
|
||||
|
||||
```
|
||||
Iteration 0 (K=0:31):
|
||||
┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐
|
||||
│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │
|
||||
│ A[0:255,│ │ (per- │ │ A_block │
|
||||
│ 0:31] │ │ thread) │ │ │
|
||||
└─────────┘ └──────────┘ └─────────┘
|
||||
│
|
||||
│ block_gemm
|
||||
↓
|
||||
┌──────────┐
|
||||
│ c_block_ │
|
||||
│ tile │
|
||||
│ (VGPRs) │
|
||||
└──────────┘
|
||||
|
||||
Iteration 1 (K=32:63):
|
||||
┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐
|
||||
│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │
|
||||
│ A[0:255,│ │ (per- │ │ A_block │
|
||||
│ 32:63] │ │ thread) │ │ (reused)│
|
||||
└─────────┘ └──────────┘ └─────────┘
|
||||
│
|
||||
│ block_gemm
|
||||
↓
|
||||
┌──────────┐
|
||||
│ c_block_ │
|
||||
│ tile │
|
||||
│ (accum.) │
|
||||
└──────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Example: Problem Size 512×256×64
|
||||
|
||||
### Block 0 Computation
|
||||
|
||||
**Input:**
|
||||
- `a_dram_block_window_tmp`: Covers A[0:255, 0:31] initially
|
||||
- `b_dram_block_window_tmp`: Covers B[0:127, 0:31] initially (B is transposed)
|
||||
- `num_loop`: 2 (since K=64, KPerBlock=32)
|
||||
|
||||
**Iteration 0:**
|
||||
1. Load A[0:255, 0:31] and B[0:127, 0:31] from DRAM to VGPRs
|
||||
2. Move windows: A → [0:255, 32:63], B → [0:127, 32:63]
|
||||
3. Store to LDS
|
||||
4. Compute: `C[0:255, 0:127] += A[0:255, 0:31] × B[0:127, 0:31]^T`
|
||||
|
||||
**Iteration 1:**
|
||||
1. Load A[0:255, 32:63] and B[0:127, 32:63] from DRAM to VGPRs
|
||||
2. Move windows: A → [0:255, 64:95], B → [0:127, 64:95] (out of bounds, but loop ends)
|
||||
3. Store to LDS
|
||||
4. Compute: `C[0:255, 0:127] += A[0:255, 32:63] × B[0:127, 32:63]^T`
|
||||
|
||||
**Output:**
|
||||
- `c_block_tile`: Contains C[0:255, 0:127] in distributed registers
|
||||
|
||||
---
|
||||
|
||||
## Key Concepts Summary
|
||||
|
||||
### 1. Tile Distribution
|
||||
- **Maps threads to data elements** for load/store operations
|
||||
- Each thread knows exactly which elements it's responsible for
|
||||
- Enables **parallel, vectorized** memory access
|
||||
- **Same distribution** used for DRAM load and LDS store
|
||||
|
||||
### 2. Static Distributed Tensor
|
||||
- **Per-thread register storage** (VGPRs)
|
||||
- Each thread owns a **different slice** of the tile
|
||||
- **Compile-time sized** for zero-overhead abstraction
|
||||
- Used for: `a_block_tile`, `b_block_tile`, `c_block_tile`
|
||||
|
||||
### 3. Tile Window Movement
|
||||
- Windows **slide** over larger tensors
|
||||
- Enables iteration over the K dimension
|
||||
- `move_tile_window(window, step)` updates the origin
|
||||
|
||||
### 4. LDS as Staging Area
|
||||
- **Shared memory** accessible to all threads in a block
|
||||
- Required because GEMM needs all threads to access all data
|
||||
- **Reused** across K iterations (same LDS buffer)
|
||||
|
||||
### 5. Synchronization
|
||||
- `block_sync_lds()` ensures memory consistency
|
||||
- **Before GEMM**: All stores to LDS are complete
|
||||
- **After GEMM**: All reads from LDS are complete
|
||||
|
||||
---
|
||||
|
||||
## Deep Dive: `static_distributed_tensor` Mechanics
|
||||
|
||||
### How Tile Distribution Creates Per-Thread Storage
|
||||
|
||||
When you call:
|
||||
```cpp
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<fp16_t>(ABlockTileDistr{}));
|
||||
ABlockTile a_block_tile;
|
||||
```
|
||||
|
||||
**Step 1: Extract Thread Tensor Descriptor**
|
||||
|
||||
The tile distribution contains a `ys_to_d_descriptor` that maps:
|
||||
- **Y dimensions** (logical tile coordinates, e.g., M, K)
|
||||
- **D dimension** (per-thread register index, linearized)
|
||||
|
||||
```cpp
|
||||
using ThreadTensorDesc =
|
||||
decltype(StaticTileDistribution{}.get_ys_to_d_descriptor());
|
||||
```
|
||||
|
||||
**Step 2: Calculate Per-Thread Buffer Size**
|
||||
|
||||
```cpp
|
||||
static constexpr index_t kThreadElementSpaceSize =
|
||||
ThreadTensorDesc{}.get_element_space_size();
|
||||
|
||||
static constexpr index_t get_thread_buffer_size()
|
||||
{
|
||||
return kThreadElementSpaceSize / PackedSize;
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
- 256×32 tile distributed across 256 threads
|
||||
- Each thread owns 32 elements (one row)
|
||||
- `thread_buffer_size = 32` (for PackedSize=1)
|
||||
|
||||
**Step 3: Allocate Thread Buffer**
|
||||
|
||||
```cpp
|
||||
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
|
||||
```
|
||||
|
||||
This is essentially:
|
||||
```cpp
|
||||
fp16_t data[32]; // Per-thread register array (VGPRs)
|
||||
```
|
||||
|
||||
### Usage in Load/Store Operations
|
||||
|
||||
**Load from DRAM:**
|
||||
```cpp
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
```
|
||||
|
||||
What happens internally:
|
||||
1. Each thread queries the tile distribution: "Which elements do I own?"
|
||||
2. Thread 0 learns it owns A[0,0:31]
|
||||
3. Thread 0 loads those elements from DRAM into `a_block_tile.thread_buf_[0:31]`
|
||||
4. All 256 threads do this **in parallel**
|
||||
|
||||
**Store to LDS:**
|
||||
```cpp
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
```
|
||||
|
||||
What happens internally:
|
||||
1. Each thread reads from its `a_block_tile.thread_buf_`
|
||||
2. Thread 0 writes A[0,0:31] from its registers to LDS
|
||||
3. All 256 threads do this **in parallel**
|
||||
4. After `block_sync_lds()`, the entire tile is in shared LDS
|
||||
|
||||
### Distributed Indexing
|
||||
|
||||
The `static_distributed_tensor` supports compile-time indexing:
|
||||
|
||||
```cpp
|
||||
// Access using distributed indices
|
||||
auto value = a_block_tile(tile_distributed_index<i, j>{});
|
||||
```
|
||||
|
||||
Internally:
|
||||
1. Convert distributed index → Y index (logical tile coordinates)
|
||||
2. Calculate buffer offset using `ThreadTensorDesc`
|
||||
3. Access `thread_buf_[offset]`
|
||||
|
||||
All of this happens **at compile time** with zero runtime overhead!
|
||||
|
||||
### Why This Design?
|
||||
|
||||
**Benefits:**
|
||||
1. **Parallel Memory Access**: All threads load/store simultaneously
|
||||
2. **Vectorization**: Each thread can use vector loads (e.g., 8×fp16 at once)
|
||||
3. **Zero Overhead**: All indexing resolved at compile time
|
||||
4. **Type Safety**: Distribution mismatch caught at compile time
|
||||
5. **Register Pressure**: Compiler knows exact VGPR usage
|
||||
|
||||
**Trade-offs:**
|
||||
- Requires compile-time tile sizes
|
||||
- Distribution must be static
|
||||
- More complex type system
|
||||
|
||||
### Memory Hierarchy Summary
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ DRAM (Global Memory) │
|
||||
│ Full matrices A, B, C │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
│ load_tile (parallel, vectorized)
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ VGPRs (Per-Thread Registers) │
|
||||
│ Thread 0: a_block_tile.thread_buf_ = [A[0,0:31]] │
|
||||
│ Thread 1: a_block_tile.thread_buf_ = [A[1,0:31]] │
|
||||
│ ... │
|
||||
│ Thread 255: a_block_tile.thread_buf_ = [A[255,0:31]] │
|
||||
│ │
|
||||
│ ← static_distributed_tensor manages this distribution │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
│ store_tile (parallel, vectorized)
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ LDS (Shared Memory) │
|
||||
│ Entire block tile (256×32) │
|
||||
│ Accessible to all threads in block │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key Insight:**
|
||||
`static_distributed_tensor` is the abstraction that enables efficient, parallel data movement between DRAM and LDS through per-thread VGPRs, with all coordination happening at compile time.
|
||||
|
||||
|
||||
|
||||
17
tutorial/ck_tile/gemm/01_naive_gemm/CMakeLists.txt
Normal file
17
tutorial/ck_tile/gemm/01_naive_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(EXAMPLE_NAIVE_GEMM "tile_tutorial_naive_gemm")
|
||||
|
||||
message(DEBUG "adding example ${EXAMPLE_NAIVE_GEMM}")
|
||||
|
||||
add_executable(${EXAMPLE_NAIVE_GEMM} EXCLUDE_FROM_ALL practice_gemm.cpp)
|
||||
target_include_directories(${EXAMPLE_NAIVE_GEMM} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_NAIVE_GEMM_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_NAIVE_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported)
|
||||
|
||||
target_compile_options(${EXAMPLE_NAIVE_GEMM} PRIVATE ${EXAMPLE_NAIVE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_dependencies(tutorials ${EXAMPLE_NAIVE_GEMM})
|
||||
618
tutorial/ck_tile/gemm/01_naive_gemm/HOST_LEVEL_PIPELINE.md
Normal file
618
tutorial/ck_tile/gemm/01_naive_gemm/HOST_LEVEL_PIPELINE.md
Normal file
@@ -0,0 +1,618 @@
|
||||
# Host-Level Pipeline: Orchestrating Block-Level GEMM
|
||||
|
||||
This document explains the **host-level pipeline** (`PracticeGemmHostPipeline`), which orchestrates the distribution of work across thread blocks and manages the high-level flow of the GEMM computation.
|
||||
|
||||
## Overview
|
||||
|
||||
The host-level pipeline is responsible for:
|
||||
1. **Calculating tile coverage**: How many tiles are needed to cover matrices A, B, and C
|
||||
2. **Block-to-tile mapping**: Assigning each thread block to a specific tile
|
||||
3. **Creating tile windows**: Establishing sliding windows over tensor views
|
||||
4. **Delegating computation**: Calling the block-level pipeline to perform actual GEMM
|
||||
5. **Storing results**: Writing computed tiles from registers (VGPRs) back to DRAM
|
||||
|
||||
```cpp
|
||||
template <typename Problem_, typename Policy_ = PracticeGemmHostPolicy>
|
||||
struct PracticeGemmHostPipeline
|
||||
{
|
||||
template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
|
||||
CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
|
||||
const BDRAMTensorView& b_dram,
|
||||
CDRAMTensorView& c_dram) const
|
||||
{
|
||||
// 1. Calculate problem dimensions and tile coverage
|
||||
// 2. Map thread block to tile coordinates
|
||||
// 3. Create tile windows over A and B
|
||||
// 4. Call block-level pipeline to compute
|
||||
// 5. Store result to C
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 1: Calculate Problem Dimensions and Tile Coverage
|
||||
|
||||
```cpp
|
||||
// Size of the entire problem
|
||||
const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K
|
||||
const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N
|
||||
const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K
|
||||
|
||||
// Size of the block tile
|
||||
const auto MPerBlock = BlockTile::at(number<0>{}); // 256
|
||||
const auto NPerBlock = BlockTile::at(number<1>{}); // 128
|
||||
const auto KPerBlock = BlockTile::at(number<2>{}); // 32
|
||||
|
||||
// Number of block tiles needed to cover C matrix
|
||||
const auto num_tile_n = integer_divide_ceil(N, NPerBlock); // ceil(256/128) = 2
|
||||
const auto num_tile_m = integer_divide_ceil(M, MPerBlock); // ceil(512/256) = 2
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
1. **Extract problem dimensions** from tensor descriptors:
|
||||
- `M = 512`: Rows in A and C
|
||||
- `N = 256`: Columns in B and C
|
||||
- `K = 64`: Inner dimension (columns of A, rows of B)
|
||||
|
||||
2. **Get block tile sizes** from the `BlockTile` configuration:
|
||||
- `MPerBlock = 256`: Each block processes 256 rows
|
||||
- `NPerBlock = 128`: Each block processes 128 columns
|
||||
- `KPerBlock = 32`: Each block processes 32 elements in K dimension per iteration
|
||||
|
||||
3. **Calculate tile coverage**:
|
||||
- `num_tile_m = ceil(M / MPerBlock) = ceil(512/256) = 2` tiles in M direction
|
||||
- `num_tile_n = ceil(N / NPerBlock) = ceil(256/128) = 2` tiles in N direction
|
||||
- **Total tiles = 2 × 2 = 4 tiles** → We need **4 thread blocks**!
|
||||
|
||||
### Visual Representation:
|
||||
|
||||
```
|
||||
Matrix C (512 × 256):
|
||||
┌──────────────────────┬──────────────────────┐
|
||||
│ Tile (0,0) │ Tile (0,1) │ ← num_tile_n = 2
|
||||
│ 256×128 │ 256×128 │
|
||||
│ Block 0 │ Block 2 │
|
||||
│ │ │
|
||||
├──────────────────────┼──────────────────────┤
|
||||
│ Tile (1,0) │ Tile (1,1) │
|
||||
│ 256×128 │ 256×128 │
|
||||
│ Block 1 │ Block 3 │
|
||||
│ │ │
|
||||
└──────────────────────┴──────────────────────┘
|
||||
↑
|
||||
num_tile_m = 2
|
||||
|
||||
Total blocks needed = 2 × 2 = 4 blocks
|
||||
|
||||
Each block computes one 256×128 tile of the output matrix C.
|
||||
```
|
||||
|
||||
### How Blocks Cover Matrices A and B:
|
||||
|
||||
```
|
||||
Matrix A (512 × 64): Matrix B (256 × 64):
|
||||
┌─────────────┬──────┐ ┌─────────────┬──────┐
|
||||
│ Block 0,2 │ K │ │ Block 0,1 │ K │
|
||||
│ uses rows │ → │ │ uses rows │ → │
|
||||
│ 0-255 │ │ │ 0-127 │ │
|
||||
├─────────────┼──────┤ ├─────────────┼──────┤
|
||||
│ Block 1,3 │ K │ │ Block 2,3 │ K │
|
||||
│ uses rows │ → │ │ uses rows │ → │
|
||||
│ 256-511 │ │ │ 128-255 │ │
|
||||
└─────────────┴──────┘ └─────────────┴──────┘
|
||||
256 rows 64 cols 128 rows 64 cols
|
||||
|
||||
Each block needs to iterate over K dimension (64/32 = 2 iterations)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 2: Map Thread Block to Tile Coordinates
|
||||
|
||||
```cpp
|
||||
// Get block id (0 to total_blocks - 1)
|
||||
const auto id_block = get_block_id();
|
||||
|
||||
// Map block id to 2D tile coordinates
|
||||
const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n);
|
||||
const auto tile_id = block2tile(id_block);
|
||||
|
||||
const auto tile_id_m = tile_id.at(number<0>{}); // M coordinate
|
||||
const auto tile_id_n = tile_id.at(number<1>{}); // N coordinate
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
Each thread block needs to know **which tile of the output matrix C it should compute**. The `MakeBlock2TileMap` function creates a mapping from linear block ID to 2D tile coordinates.
|
||||
|
||||
### The `MakeBlock2TileMap` Function:
|
||||
|
||||
```cpp
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
// Create a merge transform: (N0, M0) → linear index
|
||||
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
|
||||
|
||||
return [unmerge](index_t block_id) {
|
||||
multi_index<2> unmerged;
|
||||
// Convert linear block_id back to 2D coordinates
|
||||
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
|
||||
|
||||
// Return (m_idx, n_idx) - note the swap!
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
### In Our Example (2×2 Grid):
|
||||
|
||||
```cpp
|
||||
// Block 0:
|
||||
id_block = 0
|
||||
tile_id = block2tile(0) = (0, 0) // Top-left tile
|
||||
tile_id_m = 0, tile_id_n = 0
|
||||
|
||||
// Block 1:
|
||||
id_block = 1
|
||||
tile_id = block2tile(1) = (1, 0) // Bottom-left tile
|
||||
tile_id_m = 1, tile_id_n = 0
|
||||
|
||||
// Block 2:
|
||||
id_block = 2
|
||||
tile_id = block2tile(2) = (0, 1) // Top-right tile
|
||||
tile_id_m = 0, tile_id_n = 1
|
||||
|
||||
// Block 3:
|
||||
id_block = 3
|
||||
tile_id = block2tile(3) = (1, 1) // Bottom-right tile
|
||||
tile_id_m = 1, tile_id_n = 1
|
||||
```
|
||||
|
||||
**Key Point**: Each of the 4 blocks knows exactly which 256×128 tile of C it's responsible for computing!
|
||||
|
||||
---
|
||||
|
||||
## Step 3: Calculate Tile Origin and Create Tile Windows
|
||||
|
||||
```cpp
|
||||
// Calculate the starting position of this tile in the global matrix
|
||||
const auto tile_origin_m = tile_id_m * MPerBlock; // e.g., Block 1: 1 * 256 = 256
|
||||
const auto tile_origin_n = tile_id_n * NPerBlock; // e.g., Block 2: 1 * 128 = 128
|
||||
|
||||
// Create tile windows over A and B tensor views
|
||||
const auto a_block_window = make_tile_window(
|
||||
a_dram, // Tensor view over A
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // Window size: 256×32
|
||||
{tile_origin_m, 0} // Origin: varies by block
|
||||
);
|
||||
|
||||
const auto b_block_window = make_tile_window(
|
||||
b_dram, // Tensor view over B
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), // Window size: 128×32
|
||||
{tile_origin_n, 0} // Origin: varies by block
|
||||
);
|
||||
```
|
||||
|
||||
### Tile Origins for Each Block:
|
||||
|
||||
```cpp
|
||||
// Block 0 (Tile 0,0):
|
||||
tile_origin_m = 0 * 256 = 0
|
||||
tile_origin_n = 0 * 128 = 0
|
||||
a_block_window origin: (0, 0) → covers A rows 0-255
|
||||
b_block_window origin: (0, 0) → covers B rows 0-127
|
||||
|
||||
// Block 1 (Tile 1,0):
|
||||
tile_origin_m = 1 * 256 = 256
|
||||
tile_origin_n = 0 * 128 = 0
|
||||
a_block_window origin: (256, 0) → covers A rows 256-511
|
||||
b_block_window origin: (0, 0) → covers B rows 0-127
|
||||
|
||||
// Block 2 (Tile 0,1):
|
||||
tile_origin_m = 0 * 256 = 0
|
||||
tile_origin_n = 1 * 128 = 128
|
||||
a_block_window origin: (0, 0) → covers A rows 0-255
|
||||
b_block_window origin: (128, 0) → covers B rows 128-255
|
||||
|
||||
// Block 3 (Tile 1,1):
|
||||
tile_origin_m = 1 * 256 = 256
|
||||
tile_origin_n = 1 * 128 = 128
|
||||
a_block_window origin: (256, 0) → covers A rows 256-511
|
||||
b_block_window origin: (128, 0) → covers B rows 128-255
|
||||
```
|
||||
|
||||
### What are Tile Windows?
|
||||
|
||||
A **tile window** is a **sliding window** over a larger tensor view. It:
|
||||
- Defines a **rectangular region** within the tensor
|
||||
- Has a **fixed size** (e.g., 256×32 for A)
|
||||
- Has an **origin** (starting position)
|
||||
- Can be **moved** to access different regions
|
||||
### Visual Representation (Block 0 Example):
|
||||
|
||||
```
|
||||
Matrix A (512 × 64): Matrix B (256 × 64):
|
||||
┌─────────────┬─────────────┐ ┌─────────────┬─────────────┐
|
||||
│ ┏━━━━━━━━━┓ │ │ │ ┏━━━━━━━━━┓ │ │
|
||||
│ ┃ Window ┃ │ │ │ ┃ Window ┃ │ │
|
||||
│ ┃ 256×32 ┃ │ │ │ ┃ 128×32 ┃ │ │
|
||||
│ ┃ K=0-31 ┃ │ │ │ ┃ K=0-31 ┃ │ │
|
||||
│ ┗━━━━━━━━━┛ │ │ │ ┗━━━━━━━━━┛ │ │
|
||||
│ │ │ ├─────────────┼─────────────┤
|
||||
├─────────────┼─────────────┤ │ │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ │ │ │ │
|
||||
└─────────────┴─────────────┘ └─────────────┴─────────────┘
|
||||
Origin: (0, 0) Origin: (0, 0)
|
||||
Covers rows 0-255 Covers rows 0-127
|
||||
Covers cols 0-31 (first K iteration) Covers cols 0-31 (first K iteration)
|
||||
```
|
||||
|
||||
**Note**: The window initially covers K columns 0-31. It will move to cover K columns 32-63 in the next iteration.
|
||||
|
||||
### Tile Window Properties:
|
||||
|
||||
```cpp
|
||||
// Tile window structure (conceptual):
|
||||
struct tile_window {
|
||||
TensorView& tensor_view; // Reference to underlying tensor
|
||||
Tuple window_lengths; // Size of the window (256, 32)
|
||||
MultiIndex window_origin; // Starting position (0, 0)
|
||||
|
||||
// Can move the window:
|
||||
void move(MultiIndex step); // Shift window by step
|
||||
|
||||
// Access data through the window:
|
||||
auto load(); // Load data from windowed region
|
||||
};
|
||||
```
|
||||
|
||||
|
||||
### Tile Window Movement: Iterating Over K Dimension
|
||||
|
||||
In our example, **K=64** but **KPerBlock=32**, so we need **2 iterations** over the K dimension:
|
||||
|
||||
```
|
||||
Matrix A (512 × 64) - Block 0's view:
|
||||
┌─────────────┬─────────────┐
|
||||
│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │
|
||||
│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ ← Window slides along K
|
||||
│ ┃ 256×32 ┃ │ ║ 256×32 ║ │
|
||||
│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │
|
||||
│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │
|
||||
├─────────────┼─────────────┤
|
||||
│ │ │
|
||||
│ Block 1's │ │
|
||||
│ region │ │
|
||||
└─────────────┴─────────────┘
|
||||
|
||||
Matrix B (256 × 64) - Block 0's view:
|
||||
┌─────────────┬─────────────┐
|
||||
│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │
|
||||
│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │
|
||||
│ ┃ 128×32 ┃ │ ║ 128×32 ║ │
|
||||
│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │
|
||||
│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │
|
||||
├─────────────┼─────────────┤
|
||||
│ Block 2's │ │
|
||||
│ region │ │
|
||||
└─────────────┴─────────────┘
|
||||
```
|
||||
|
||||
### How Windows Move (Conceptual - handled by block pipeline):
|
||||
|
||||
```cpp
|
||||
// Iteration 0:
|
||||
a_block_window origin: (tile_origin_m, 0) // K columns 0-31
|
||||
b_block_window origin: (tile_origin_n, 0) // K columns 0-31
|
||||
// Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31]ᵀ  (B is stored N×K, so the product contracts over K)
|
||||
|
||||
// Move windows to next K position:
|
||||
move_tile_window(a_block_window, {0, 32});
|
||||
move_tile_window(b_block_window, {0, 32});
|
||||
|
||||
// Iteration 1:
|
||||
a_block_window origin: (tile_origin_m, 32) // K columns 32-63
|
||||
b_block_window origin: (tile_origin_n, 32) // K columns 32-63
|
||||
// Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63]ᵀ  (B is stored N×K, so the product contracts over K)
|
||||
|
||||
// Final result:
|
||||
// C_tile = C_partial_0 + C_partial_1
|
||||
```
|
||||
|
||||
**Key Insight**: The tile windows **slide along the K dimension** to cover the full inner product. Each block accumulates partial results across K iterations to compute its final tile of C.
|
||||
|
||||
---
|
||||
|
||||
## Step 4: Delegate to Block-Level Pipeline
|
||||
|
||||
```cpp
|
||||
// Get the block-level pipeline from policy
|
||||
constexpr auto block_gemm_pipeline =
|
||||
Policy::template GetPracticeGemmBlockPipeline<Problem>();
|
||||
|
||||
// Calculate number of K iterations needed
|
||||
int num_loops_k = integer_divide_ceil(K, KPerBlock); // ceil(64/32) = 2
|
||||
|
||||
// Allocate shared memory (LDS) for block-level computation
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()];
|
||||
|
||||
// Call block-level pipeline to compute C tile
|
||||
const auto c_block_tile =
|
||||
block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char);
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
1. **Retrieve block pipeline**: The policy provides the block-level GEMM implementation
|
||||
2. **Calculate K iterations**: How many times to iterate over the K dimension
|
||||
- In our example: `K=64, KPerBlock=32` → **2 iterations**
|
||||
- Each iteration processes 32 elements of the K dimension
|
||||
- Results are accumulated across iterations
|
||||
|
||||
3. **Allocate shared memory**:
|
||||
- `__shared__` declares memory shared by all threads in the block
|
||||
- `GetStaticLDSSize()` returns the required size in bytes
|
||||
- This memory is used for:
|
||||
- Staging data from DRAM → LDS
|
||||
- Cooperative loading by threads
|
||||
- Fast access during computation
|
||||
|
||||
4. **Execute block pipeline**:
|
||||
- Takes A and B tile windows as input
|
||||
- Performs the GEMM computation: `C_tile = A_tile × B_tile`
|
||||
- Returns result in `c_block_tile` (stored in VGPRs - registers)
|
||||
|
||||
### Memory Hierarchy During Computation:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ DRAM (Global Memory) - Slowest, Largest │
|
||||
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
|
||||
│ │ A matrix │ │ B matrix │ │ C matrix │ │
|
||||
│ └─────────────┘ └─────────────┘ └─────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
↓ load ↓ load ↑ store
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ LDS (Shared Memory) - Fast, Limited Size (~64KB) │
|
||||
│ ┌─────────────┐ ┌─────────────┐ │
|
||||
│ │ A_tile │ │ B_tile │ ← Staged here │
|
||||
│ │ (p_smem) │ │ (p_smem) │ │
|
||||
│ └─────────────┘ └─────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
↓ load ↓ load
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ VGPRs (Registers) - Fastest, Smallest (~256 regs/thread) │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ c_block_tile (accumulated result) │ │
|
||||
│ │ Computation happens here using MFMA instructions │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Block Pipeline Responsibilities:
|
||||
|
||||
The block pipeline (called here) will:
|
||||
1. Load A and B tiles from DRAM → LDS (cooperative loading)
|
||||
2. Distribute work among warps
|
||||
3. Each warp loads its portion from LDS → VGPRs
|
||||
4. Perform MFMA operations: `C += A × B`
|
||||
5. Accumulate results in VGPRs
|
||||
6. Return final `c_block_tile` in registers
|
||||
|
||||
---
|
||||
|
||||
## Step 5: Store Results to DRAM
|
||||
|
||||
```cpp
|
||||
// Create a tile window over C for writing results
|
||||
auto c_window = make_tile_window(
|
||||
c_dram, // Tensor view over C
|
||||
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}), // Window size: 256×128
|
||||
{tile_origin_m, tile_origin_n} // Origin: varies by block
|
||||
);
|
||||
|
||||
// Store computed tile from VGPRs to DRAM
|
||||
store_tile(c_window, c_block_tile);
|
||||
```
|
||||
|
||||
### C Window Origins for Each Block:
|
||||
|
||||
```cpp
|
||||
// Block 0: Writes to top-left tile
|
||||
c_window origin: (0, 0) → writes to C[0:255, 0:127]
|
||||
|
||||
// Block 1: Writes to bottom-left tile
|
||||
c_window origin: (256, 0) → writes to C[256:511, 0:127]
|
||||
|
||||
// Block 2: Writes to top-right tile
|
||||
c_window origin: (0, 128) → writes to C[0:255, 128:255]
|
||||
|
||||
// Block 3: Writes to bottom-right tile
|
||||
c_window origin: (256, 128) → writes to C[256:511, 128:255]
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
1. **Create C tile window**:
|
||||
- Size: 256×128 (matches our block tile size)
|
||||
- Origin: Varies by block - each block writes to its assigned region
|
||||
- This window defines **where** to write the results
|
||||
|
||||
2. **Store tile to DRAM**:
|
||||
- `c_block_tile`: Computed results in VGPRs (registers)
|
||||
- `c_window`: Destination window in DRAM
|
||||
- `store_tile()`: Efficiently writes data from registers → DRAM
|
||||
|
||||
### The `store_tile` Function:
|
||||
|
||||
Recall from our earlier discussion that `store_tile` does the following:
|
||||
|
||||
```cpp
|
||||
template <typename TileWindow, typename DistributedTensor>
|
||||
void store_tile(TileWindow& tile_window_tmp,
|
||||
const DistributedTensor& dstr_tensor)
|
||||
{
|
||||
// 1. Extract tile distribution from distributed tensor
|
||||
using TileDstr = typename DistributedTensor::TileDistribution;
|
||||
|
||||
// 2. Upgrade simple tile window to one with distribution
|
||||
auto tile_window = make_tile_window(
|
||||
tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
TileDstr{} // Add distribution info
|
||||
);
|
||||
|
||||
// 3. Store using vectorized writes
|
||||
tile_window.store(dstr_tensor);
|
||||
}
|
||||
```
|
||||
|
||||
### Memory Flow:
|
||||
|
||||
```
|
||||
VGPRs (Registers) DRAM (Global Memory)
|
||||
┌─────────────────────┐ ┌─────────────────────┐
|
||||
│ c_block_tile │ │ C matrix │
|
||||
│ ┌───┬───┬───┬───┐ │ │ ┌───────────────┐ │
|
||||
│ │W0 │W1 │W2 │W3 │ │ store_tile │ │ │ │
|
||||
│ ├───┼───┼───┼───┤ │ ==========> │ │ c_window │ │
|
||||
│ │...│...│...│...│ │ vectorized │ │ (256×128) │ │
|
||||
│ └───┴───┴───┴───┘ │ │ │ │ │
|
||||
│ Distributed across │ │ └───────────────┘ │
|
||||
│ threads/warps │ │ Origin: (0, 0) │
|
||||
└─────────────────────┘ └─────────────────────┘
|
||||
|
||||
Each thread writes its portion using vector stores (e.g., float4)
|
||||
```
|
||||
|
||||
### Store Optimization:
|
||||
|
||||
The `store_tile` function:
|
||||
- Uses **vectorized stores** (write multiple elements at once)
|
||||
- Ensures **coalesced memory access** (adjacent threads write adjacent memory)
|
||||
- Respects **tile distribution** (each thread knows what data it owns)
|
||||
- Handles **out-of-bounds** checking (for partial tiles at boundaries)
|
||||
|
||||
---
|
||||
|
||||
## Complete Flow Visualization
|
||||
|
||||
Let's trace the complete flow for **Block 0** (other blocks follow the same pattern):
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 1: Calculate Tile Coverage │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ M=512, N=256, K=64 │ │
|
||||
│ │ MPerBlock=256, NPerBlock=128, KPerBlock=32 │ │
|
||||
│ │ num_tile_m = ceil(512/256) = 2 │ │
|
||||
│ │ num_tile_n = ceil(256/128) = 2 │ │
|
||||
│ │ Total blocks needed = 2 × 2 = 4 blocks │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 2: Map Block to Tile (Block 0 example) │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Block ID: 0 │ │
|
||||
│ │ Tile coordinates: (0, 0) - top-left tile │ │
|
||||
│ │ Tile origin: (0, 0) │ │
|
||||
│ │ │ │
|
||||
│ │ (Blocks 1,2,3 get different tile coordinates) │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 3: Create Tile Windows │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ a_block_window: 256×32 starting at (0,0) over A │ │
|
||||
│ │ b_block_window: 128×32 starting at (0,0) over B │ │
|
||||
│ │ Windows initially cover K columns 0-31 │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 4: Execute Block Pipeline (2 K iterations) │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Allocate shared memory (LDS) │ │
|
||||
│ │ Call block_gemm_pipeline(a_window, b_window, 2, p_smem) │ │
|
||||
│ │ │ │
|
||||
│ │ K Iteration 0 (K=0-31): │ │
|
||||
│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] │ │
|
||||
│ │ └─ Move windows: {0, 32} │ │
|
||||
│ │ │ │
|
||||
│ │ K Iteration 1 (K=32-63): │ │
|
||||
│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] │ │
|
||||
│ │ └─ Accumulate: C_tile = C_partial_0 + C_partial_1 │ │
|
||||
│ │ │ │
|
||||
│ │ Return c_block_tile in VGPRs (256×128 accumulated result) │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 5: Store Results │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Create c_window: 256×128 starting at (0,0) over C │ │
|
||||
│ │ store_tile(c_window, c_block_tile) │ │
|
||||
│ │ └─ Write from VGPRs → DRAM (vectorized stores) │ │
|
||||
│ │ │ │
|
||||
│ │ Block 0 writes to C[0:255, 0:127] │ │
|
||||
│ │ (Other blocks write to their respective regions) │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
|
||||
All 4 blocks execute in parallel, each computing its assigned 256×128 tile!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Concepts Summary
|
||||
|
||||
### 1. **Tile Coverage**
|
||||
- Determines how many thread blocks are needed
|
||||
- Each block processes one tile of the output matrix C
|
||||
- Calculated as `ceil(dimension / tile_size)`
|
||||
|
||||
### 2. **Block-to-Tile Mapping**
|
||||
- Maps linear block ID to 2D tile coordinates
|
||||
- Uses column-major ordering for better memory coalescing
|
||||
- Each block knows which tile it's responsible for
|
||||
|
||||
### 3. **Tile Windows**
|
||||
- **Sliding windows** over larger tensor views
|
||||
- Define a rectangular region with fixed size and movable origin
|
||||
- Provide efficient, structured access to tensor data
|
||||
- Can be moved to access different regions (e.g., for K iterations)
|
||||
|
||||
### 4. **Memory Hierarchy**
|
||||
- **DRAM (Global)**: Largest, slowest - stores full matrices
|
||||
- **LDS (Shared)**: Medium, fast - stages tiles for cooperative access
|
||||
- **VGPRs (Registers)**: Smallest, fastest - performs computation
|
||||
|
||||
### 5. **Data Flow**
|
||||
```
|
||||
DRAM → Tile Windows → LDS → VGPRs → Computation → VGPRs → DRAM
|
||||
↑ ↓
|
||||
A, B matrices C matrix
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Next Steps
|
||||
|
||||
The host-level pipeline has set up the work and delegated to the block-level pipeline. Next, we'll explore:
|
||||
- **Block-level pipeline**: How tiles are loaded, distributed to warps, and computed
|
||||
- **Warp-level pipeline**: How warps perform MFMA operations
|
||||
- **Memory optimization**: LDS usage, bank conflicts, coalescing
|
||||
|
||||
The host level provides the **orchestration**, while the block and warp levels provide the **execution**!
|
||||
|
||||
464
tutorial/ck_tile/gemm/01_naive_gemm/KERNEL_ENTRY_POINT.md
Normal file
464
tutorial/ck_tile/gemm/01_naive_gemm/KERNEL_ENTRY_POINT.md
Normal file
@@ -0,0 +1,464 @@
|
||||
# PracticeGemmKernel: Understanding the Kernel Entry Point
|
||||
|
||||
This document explains the `PracticeGemmKernel` structure, which serves as the **entry point** for our GEMM GPU kernel. We'll dive deep into how raw memory is transformed into structured tensor views.
|
||||
|
||||
## Overview
|
||||
|
||||
The `PracticeGemmKernel` is a templated struct that:
|
||||
1. Takes raw device memory pointers for matrices A, B, and C
|
||||
2. Wraps them into **tensor views** - logical, structured views over physical memory
|
||||
3. Dispatches to the host-level pipeline for computation
|
||||
|
||||
```cpp
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct PracticeGemmKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
|
||||
const typename Problem::BDataType* p_b,
|
||||
typename Problem::CDataType* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t stride_a,
|
||||
const index_t stride_b,
|
||||
const index_t stride_c) const
|
||||
{
|
||||
// Step 1: Create tensor views over raw memory
|
||||
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{});
|
||||
|
||||
auto b_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{});
|
||||
|
||||
const auto c_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{});
|
||||
|
||||
// Step 2: Dispatch to host-level pipeline
|
||||
PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## What are Tensor Views?
|
||||
|
||||
A **tensor view** is a **logical, structured view over raw physical memory**. It doesn't own or allocate memory—it simply provides a way to interpret and access existing memory as a multi-dimensional tensor.
|
||||
|
||||
### Key Components of a Tensor View:
|
||||
|
||||
1. **Memory Type**: Where the data lives (global/DRAM, LDS/shared, registers)
|
||||
2. **Raw Pointer**: Points to the actual data in memory
|
||||
3. **Shape**: Dimensions of the tensor (e.g., M×K for matrix A)
|
||||
4. **Strides**: How to navigate through memory to access elements
|
||||
5. **Guaranteed Vector Length**: How many consecutive elements can be loaded in one vector instruction
|
||||
6. **Guaranteed Vector Stride**: The stride of those vectorizable elements
|
||||
|
||||
---
|
||||
|
||||
## The Memory Abstraction Hierarchy
|
||||
|
||||
CK Tile uses a three-layer abstraction to go from raw memory to structured tensors:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Layer 3: TENSOR VIEW │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ • Logical multi-dimensional structure │ │
|
||||
│ │ • Shape: (M, K) = (256, 32) │ │
|
||||
│ │ • Strides: (32, 1) for row-major layout │ │
|
||||
│ │ • Provides: operator[], coordinate-based access │ │
|
||||
│ │ • Knows: How to map (i,j) → linear offset │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
│ ↓ wraps │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ Layer 2: BUFFER VIEW │ │
|
||||
│ │ ┌─────────────────────────────────────────────────────┐ │ │
|
||||
│ │ │ • Linear view of memory │ │ │
|
||||
│ │ │ • Pointer: p_data_ → device memory │ │ │
|
||||
│ │ │ • Size: Total number of elements │ │ │
|
||||
│ │ │ • Address space: global/LDS/generic │ │ │
|
||||
│ │ │ • Provides: Vectorized loads/stores, bounds checking│ │ │
|
||||
│ │ └─────────────────────────────────────────────────────┘ │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
│ ↓ wraps │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ Layer 1: RAW PHYSICAL MEMORY │ │
|
||||
│ │ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ │
|
||||
│ │ │ 0.0 │ 1.0 │ 2.0 │ 3.0 │ 4.0 │ 5.0 │ 6.0 │ 7.0 │ ... │ │ │
|
||||
│ │ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ │ │
|
||||
│ │ ↑ │ │
|
||||
│ │ p_a (raw pointer from hipMalloc) │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Deep Dive: `make_naive_tensor_view`
|
||||
|
||||
Let's break down the function call for matrix A:
|
||||
|
||||
```cpp
|
||||
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, // Raw pointer to device memory
|
||||
make_tuple(M, K), // Shape: (256, 32)
|
||||
make_tuple(stride_a, 1), // Strides: (32, 1) - row-major
|
||||
number<8>{}, // Guaranteed vector length
|
||||
number<1>{} // Guaranteed vector stride
|
||||
);
|
||||
```
|
||||
|
||||
### Function Signature:
|
||||
|
||||
```cpp
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view(DataType* __restrict__ p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
// Step 1: Create tensor descriptor (shape + stride information)
|
||||
auto desc = make_naive_tensor_descriptor(lengths,
|
||||
strides,
|
||||
number<GuaranteedLastDimensionVectorLength>{},
|
||||
number<GuaranteedLastDimensionVectorStride>{});
|
||||
|
||||
// Step 2: Create buffer view (pointer + size + address space)
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
// Step 3: Combine into tensor view
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Parameter Breakdown
|
||||
|
||||
### 1. **Template Parameter: `address_space_enum::global`**
|
||||
|
||||
Specifies where the memory lives:
|
||||
- `global`: GPU global memory (DRAM) - slowest but largest
|
||||
- `lds`: Local Data Share (shared memory) - fast, limited size
|
||||
- `generic`: Generic address space
|
||||
- `vgpr`: Vector General Purpose Registers - fastest, smallest
|
||||
|
||||
In our case, `global` means the data is in GPU DRAM.
|
||||
|
||||
### 2. **`p_a` - Raw Pointer**
|
||||
|
||||
The raw device memory pointer returned by `hipMalloc`. Points to the start of the matrix data.
|
||||
|
||||
### 3. **`make_tuple(M, K)` - Shape/Lengths**
|
||||
|
||||
Defines the logical dimensions of the tensor:
|
||||
- For matrix A: `(256, 32)` means 256 rows, 32 columns
|
||||
- This is the **logical view**, independent of how data is physically laid out
|
||||
|
||||
### 4. **`make_tuple(stride_a, 1)` - Strides**
|
||||
|
||||
Defines how to navigate through memory:
|
||||
- **Stride for dimension 0 (rows)**: `stride_a = K = 32`
|
||||
- To move to the next row, skip 32 elements
|
||||
- **Stride for dimension 1 (columns)**: `1`
|
||||
- To move to the next column, skip 1 element
|
||||
|
||||
**Row-major layout example:**
|
||||
```
|
||||
Memory: [a₀₀, a₀₁, a₀₂, ..., a₀₃₁, a₁₀, a₁₁, a₁₂, ..., a₁₃₁, ...]
|
||||
↑ ↑
|
||||
Row 0 starts here Row 1 starts here (offset = 32)
|
||||
|
||||
To access element A[i][j]:
|
||||
offset = i * stride_a + j * 1
|
||||
= i * 32 + j
|
||||
```
|
||||
|
||||
### 5. **`number<8>{}` - Guaranteed Last Dimension Vector Length**
|
||||
|
||||
This tells the tensor view: **"The last dimension (K) is guaranteed to have at least 8 consecutive elements that can be loaded together in a single vector instruction."**
|
||||
|
||||
#### Why is this important?
|
||||
|
||||
Modern GPUs can load multiple elements in one instruction (vectorized loads):
|
||||
- `float4`: Load 4 floats at once
|
||||
- `float8`: Load 8 floats at once (if supported)
|
||||
|
||||
By specifying `number<8>{}`, we're telling the system:
|
||||
- "You can safely use vector loads of up to 8 elements"
|
||||
- "The memory alignment and layout support this"
|
||||
|
||||
**Example:**
|
||||
```cpp
|
||||
// Without vectorization (slow):
|
||||
for (int j = 0; j < 8; j++) {
|
||||
data[j] = memory[offset + j]; // 8 separate loads
|
||||
}
|
||||
|
||||
// With vectorization (fast):
|
||||
float8 vec = *reinterpret_cast<float8*>(&memory[offset]); // 1 load!
|
||||
```
|
||||
|
||||
### 6. **`number<1>{}` - Guaranteed Last Dimension Vector Stride**
|
||||
|
||||
This specifies the **stride between consecutive vectorizable elements** in the last dimension.
|
||||
|
||||
- `number<1>{}` means: "Consecutive elements in the last dimension are contiguous in memory (stride = 1)"
|
||||
- This confirms that elements `A[i][0], A[i][1], A[i][2], ..., A[i][7]` are stored consecutively
|
||||
|
||||
**Why does this matter?**
|
||||
|
||||
For efficient vectorized loads, elements must be:
|
||||
1. **Contiguous** (stride = 1) ✓
|
||||
2. **Aligned** properly in memory
|
||||
3. **Within the same cache line** (ideally)
|
||||
|
||||
If the stride were `2`, it would mean:
|
||||
```
|
||||
A[i][0] is at offset 0
|
||||
A[i][1] is at offset 2 (not 1!)
|
||||
A[i][2] is at offset 4
|
||||
```
|
||||
This would prevent efficient vectorization.
|
||||
|
||||
---
|
||||
|
||||
## What is a Buffer View?
|
||||
|
||||
A **buffer view** is the middle layer between raw memory and tensor view. It provides:
|
||||
|
||||
### Core Responsibilities:
|
||||
|
||||
1. **Memory Management**
|
||||
- Holds the raw pointer: `T* p_data_`
|
||||
- Tracks buffer size: `BufferSizeType buffer_size_`
|
||||
- Knows the address space: `global`, `lds`, etc.
|
||||
|
||||
2. **Vectorized Access**
|
||||
```cpp
|
||||
template <typename VectorType>
|
||||
CK_TILE_DEVICE VectorType get(index_t offset);
|
||||
```
|
||||
- Provides efficient vector loads/stores
|
||||
- Handles alignment requirements
|
||||
|
||||
3. **Bounds Checking** (optional)
|
||||
```cpp
|
||||
template <bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto get(index_t i, index_t linear_offset);
|
||||
```
|
||||
- Can optionally check if access is within bounds
|
||||
- Returns invalid value (default 0) for out-of-bounds access
|
||||
|
||||
4. **Address Space Awareness**
|
||||
- Uses different load/store instructions based on address space
|
||||
- Global memory: `global_load`, `global_store`
|
||||
- LDS: `ds_read`, `ds_write`
|
||||
|
||||
### Buffer View Structure:
|
||||
|
||||
```cpp
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
typename T,
|
||||
typename BufferSizeType,
|
||||
bool InvalidElementUseNumericalZeroValue,
|
||||
amd_buffer_coherence_enum Coherence>
|
||||
struct buffer_view
|
||||
{
|
||||
T* p_data_; // Raw pointer
|
||||
BufferSizeType buffer_size_; // Total elements
|
||||
remove_cvref_t<T> invalid_element_value_; // Value for OOB access
|
||||
|
||||
// Access operators
|
||||
const T& operator[](index_t i) const; // Read
|
||||
T& operator()(index_t i); // Write
|
||||
|
||||
// Vectorized access
|
||||
template <typename VectorType>
|
||||
VectorType get(index_t offset);
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Visual Example: Matrix A Memory Layout
|
||||
|
||||
Let's visualize how matrix A (256×32, fp16) is organized:
|
||||
|
||||
### Raw Physical Memory (Linear):
|
||||
```
|
||||
GPU DRAM Address Space:
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Byte 0 │
|
||||
│ ↓ │
|
||||
│ [a₀₀][a₀₁][a₀₂]...[a₀₃₁][a₁₀][a₁₁][a₁₂]...[a₁₃₁][a₂₀]... │
|
||||
│ ↑ ↑ │
|
||||
│ Row 0 (32 elements) Row 1 (32 elements) │
|
||||
│ │
|
||||
│ Total: 256 rows × 32 cols × 2 bytes/element = 16,384 bytes │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↑
|
||||
p_a (raw pointer)
|
||||
```
|
||||
|
||||
### Buffer View Layer:
|
||||
```
|
||||
buffer_view<address_space_enum::global, fp16_t, ...>
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ p_data_ = p_a │
|
||||
│ buffer_size_ = 256 × 32 = 8,192 elements │
|
||||
│ address_space = global (DRAM) │
|
||||
│ │
|
||||
│ Provides: │
|
||||
│ • Linear indexing: buffer_view[i] → element at offset i │
|
||||
│ • Vectorized loads: get<float4>(offset) → load 4 fp16s at once│
|
||||
│ • Bounds checking: is offset < buffer_size_? │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Tensor View Layer:
|
||||
```
|
||||
tensor_view<buffer_view, tensor_descriptor>
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Shape: (256, 32) │
|
||||
│ Strides: (32, 1) │
|
||||
│ Guaranteed vector length: 8 │
|
||||
│ Guaranteed vector stride: 1 │
|
||||
│ │
|
||||
│ Logical 2D View: │
|
||||
│ Col: 0 1 2 ... 31 │
|
||||
│ Row 0: [a₀₀][a₀₁][a₀₂] ... [a₀₃₁] ← Can vector load 8 at once│
|
||||
│ Row 1: [a₁₀][a₁₁][a₁₂] ... [a₁₃₁] │
|
||||
│ Row 2: [a₂₀][a₂₁][a₂₂] ... [a₂₃₁] │
|
||||
│ ... │
|
||||
│ Row 255: [a₂₅₅,₀] ... [a₂₅₅,₃₁] │
|
||||
│ │
|
||||
│ Provides: │
|
||||
│ • Multi-dimensional indexing: A[i][j] │
|
||||
│ • Coordinate transformation: (i,j) → linear offset = i*32 + j │
|
||||
│ • Tile window creation: Extract sub-tensors │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Complete Flow: Raw Memory → Tensor View
|
||||
|
||||
Let's trace the complete transformation for matrix A:
|
||||
|
||||
### Step 1: Kernel Launch (Host Side)
|
||||
```cpp
|
||||
// On host: Allocate device memory
|
||||
hipMalloc(&p_a, M * K * sizeof(fp16_t)); // Returns raw pointer
|
||||
|
||||
// Launch kernel
|
||||
kernel<<<grid, block>>>(p_a, p_b, p_c, M, N, K, ...);
|
||||
```
|
||||
|
||||
### Step 2: Inside Kernel (Device Side)
|
||||
```cpp
|
||||
// Receive raw pointer
|
||||
const fp16_t* p_a; // Points to GPU DRAM
|
||||
|
||||
// Step 2a: Create tensor descriptor
|
||||
auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(256, 32), // Shape
|
||||
make_tuple(32, 1), // Strides
|
||||
number<8>{}, // Vector length
|
||||
number<1>{} // Vector stride
|
||||
);
|
||||
// desc now knows: "This is a 256×32 tensor, row-major, vectorizable by 8"
|
||||
|
||||
// Step 2b: Create buffer view
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_a, // Raw pointer
|
||||
256 * 32 // Total elements
|
||||
);
|
||||
// buffer_view now wraps p_a with size and address space info
|
||||
|
||||
// Step 2c: Create tensor view
|
||||
auto a_dram = tensor_view{buffer_view, desc};
|
||||
// a_dram now provides structured, multi-dimensional access to p_a
|
||||
```
|
||||
|
||||
### Step 3: Using the Tensor View
|
||||
```cpp
|
||||
// Access element A[i][j]
|
||||
auto value = a_dram[make_tuple(i, j)];
|
||||
|
||||
// Create a tile window (sub-tensor)
|
||||
auto tile = make_tile_window(
|
||||
a_dram,
|
||||
make_tuple(16, 16), // 16×16 tile
|
||||
make_tuple(0, 0) // Starting at origin
|
||||
);
|
||||
|
||||
// Load tile into registers with vectorization
|
||||
auto tile_data = load_tile(tile); // Uses vector loads internally!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Why This Abstraction?
|
||||
|
||||
### Benefits:
|
||||
|
||||
1. **Type Safety**: Can't accidentally access wrong dimensions
|
||||
2. **Performance**: Compiler knows about vectorization opportunities
|
||||
3. **Flexibility**: Same code works for different memory spaces (DRAM, LDS, registers)
|
||||
4. **Maintainability**: Logical structure separate from physical layout
|
||||
5. **Optimization**: Guaranteed vector properties enable aggressive optimizations
|
||||
|
||||
### Example: Without Tensor Views (Manual Indexing)
|
||||
```cpp
|
||||
// Ugly, error-prone, hard to optimize:
|
||||
for (int i = 0; i < 16; i++) {
|
||||
for (int j = 0; j < 16; j++) {
|
||||
float val = p_a[tile_offset_i * stride_a + tile_offset_j + i * stride_a + j];
|
||||
// Hope the compiler vectorizes this? 🤞
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example: With Tensor Views (Clean, Optimized)
|
||||
```cpp
|
||||
// Clean, safe, automatically vectorized:
|
||||
auto tile = make_tile_window(a_dram, make_tuple(16, 16), origin);
|
||||
auto tile_data = load_tile(tile); // Vectorized loads guaranteed!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
The `PracticeGemmKernel` entry point transforms raw GPU memory into structured, multi-dimensional tensors through a three-layer abstraction:
|
||||
|
||||
1. **Raw Memory**: Linear array of bytes in GPU DRAM
|
||||
2. **Buffer View**: Adds size, address space, and vectorized access
|
||||
3. **Tensor View**: Adds shape, strides, and multi-dimensional indexing
|
||||
|
||||
This abstraction enables:
|
||||
- ✅ Clean, readable code
|
||||
- ✅ Type-safe multi-dimensional access
|
||||
- ✅ Automatic vectorization
|
||||
- ✅ Flexible memory space handling
|
||||
- ✅ Efficient tile-based computation
|
||||
|
||||
The tensor views created here are then passed to the host-level pipeline, which orchestrates the block-level GEMM computation!
|
||||
|
||||
150
tutorial/ck_tile/gemm/01_naive_gemm/README.md
Normal file
150
tutorial/ck_tile/gemm/01_naive_gemm/README.md
Normal file
@@ -0,0 +1,150 @@
|
||||
# CK Tile Practice GEMM Example
|
||||
|
||||
This is a practice implementation of a GEMM (General Matrix Multiplication) kernel using the CK Tile API. It demonstrates the fundamental concepts of GPU kernel development using CK Tile's hierarchical tile system.
|
||||
|
||||
## CK Tile API Structure
|
||||
|
||||
In the composable_kernel library's ck_tile API, **A Kernel is composed of a Problem, a Policy and an Epilogue**:
|
||||
|
||||
1. **Problem** describes the shape, data type, data layout, precision of our GEMM matrices
|
||||
2. **Policy** describes how the data in the matrix (or tile) is mapped to the threads
|
||||
3. **Epilogue** describes additional computation work performed after the gemm computations (this example does not have an epilogue)
|
||||
|
||||
## Overview
|
||||
|
||||
This example implements a complete GEMM kernel `C = A × B` using the CK Tile framework, showcasing:
|
||||
|
||||
- **Problem Setup** - Setting up the problem (input/output shapes, data types, mathematical operations), composing a kernel (pipeline, policy, epilogue), kernel launch
|
||||
- **Block-level Pipelining** - creating tensor views, dispatching to block-level GEMM
|
||||
- **Block-level GEMM Computation** - Block tiles, tile window creation, loading/storing to DRAM and Register memory
|
||||
- **Warp-level GEMM Computation** - Warp tiles, MFMA level computation
|
||||
|
||||
## Problem Setup and Data Flow
|
||||
|
||||
### Problem Size Configuration
|
||||
We set the problem size using the M, N and K variables:
|
||||
```cpp
|
||||
ck_tile::index_t M = 1024; // Number of rows in A and C
|
||||
ck_tile::index_t N = 512; // Number of columns in B and C
|
||||
ck_tile::index_t K = 256; // Number of columns in A, rows in B
|
||||
```
|
||||
|
||||
### Host Matrix Creation
|
||||
Three host matrices A (M×K), B (N×K) and C (M×N) are created, initialized on the CPU and copied over to the GPU global/DRAM memory:
|
||||
```cpp
|
||||
// Host tensors with proper strides
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides); // M × K
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides); // N × K
|
||||
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides); // M × N
|
||||
|
||||
// Initialize with random data
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
|
||||
|
||||
// Allocate device memory and transfer data
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
a_device.ToDevice(a_host.data());
|
||||
```
|
||||
|
||||
### PracticeGemmShape Configuration
|
||||
A PracticeGemmShape struct holds the dimension of each BlockTile and WaveTile:
|
||||
|
||||
```cpp
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>; // M, N, K per block
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>; // M, N, K per wave
|
||||
```
|
||||
- A BlockTile of size MxK (256x32) on A matrix and NxK (128x32) on B matrix. A WaveTile of size MxN (16x16) on C matrix.
|
||||
|
||||
|
||||
- BlockTiles iterate in K dimension to fetch data required for computing region of C covered by C's block tile.
|
||||
- BlockTiles are further subdivided into WarpTiles.
|
||||
- WarpTiles over A and B similarly work together to calculate the WarpTile of C.
|
||||
|
||||
### Problem and Policy Composition
|
||||
```cpp
|
||||
// A Problem is composed from Shape and info about the data
|
||||
using PracticeGemmHostProblem = ck_tile::
|
||||
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
|
||||
|
||||
// A Policy is created describing data-to-thread mapping
|
||||
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
|
||||
|
||||
// A Kernel is then composed of Problem and Policy
|
||||
using gemm_kernel = ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
|
||||
```
|
||||
|
||||
### Kernel Launch
|
||||
`ck_tile::launch_kernel()` is used to launch the kernel on device. It calls the `operator()` function of `PracticeGemmKernel{}`:
|
||||
```cpp
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 0, 1},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCU>(
|
||||
gemm_kernel{}, // Kernel composed of Problem + Policy
|
||||
kGridSize, // Grid dimensions
|
||||
kBlockSize, // Block dimensions
|
||||
0, // Dynamic shared memory
|
||||
// Kernel arguments: device buffers and problem dimensions
|
||||
a_device.GetDeviceBuffer(), b_device.GetDeviceBuffer(), c_device.GetDeviceBuffer(),
|
||||
M, N, K, stride_a, stride_b, stride_c));
|
||||
```
|
||||
|
||||
### Result Verification
|
||||
The results from the kernel are compared with results from CPU based computation function:
|
||||
```cpp
|
||||
// CPU reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(a_host, b_host, c_host_ref);
|
||||
|
||||
// Device results
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
|
||||
// Verify correctness
|
||||
bool pass = ck_tile::check_err(c_host_dev, c_host_ref);
|
||||
```
|
||||
|
||||
### Runtime Flow
|
||||
|
||||
The main program (`practice_gemm.cpp`) is the entry point for the runtime flow:
|
||||
|
||||
```cpp
|
||||
int main()
|
||||
{
|
||||
// 1. Define data types and problem sizes
|
||||
using ADataType = ck_tile::half_t;
|
||||
ck_tile::index_t M = 2048, N = 1024, K = 512;
|
||||
|
||||
// 2. Create host tensors and initialize
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
|
||||
// 3. Allocate device memory and transfer data
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
|
||||
// 4. Configure tile shapes
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>;
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// 5. Launch kernel
|
||||
using gemm_kernel = ck_tile::PracticeGemmKernel<Problem, Policy>;
|
||||
float ave_time = ck_tile::launch_kernel(/*...*/);
|
||||
|
||||
// 6. Verify results
|
||||
bool pass = verify_results(a_host, b_host, c_host);
|
||||
|
||||
// 7. Print performance metrics
|
||||
print_performance_metrics(ave_time, M, N, K);
|
||||
}
|
||||
```
|
||||
|
||||
## Building and Running
|
||||
|
||||
```bash
|
||||
# From composable_kernel root directory
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_practice_gemm -j
|
||||
|
||||
# Run with sample sizes
|
||||
./bin/tile_example_practice_gemm
|
||||
```
|
||||
This example serves as a foundation for understanding more complex GEMM implementations and optimization strategies in the CK Tile framework.
|
||||
312
tutorial/ck_tile/gemm/01_naive_gemm/TILE_DISTRIBUTION.md
Normal file
312
tutorial/ck_tile/gemm/01_naive_gemm/TILE_DISTRIBUTION.md
Normal file
@@ -0,0 +1,312 @@
|
||||
# Tile Distribution: Mapping Threads to Data
|
||||
|
||||
## Overview
|
||||
|
||||
**Tile Distribution** describes how each thread in a thread block maps to elements of a block tile. It defines the hierarchical pattern of data distribution across threads, warps, and thread blocks.
|
||||
|
||||
## The Problem
|
||||
|
||||
Given a block tile of size `MPerBlock × KPerBlock` (e.g., 256×32), we need to determine:
|
||||
- Which threads load which elements.
|
||||
- How the threads are organized into warps.
|
||||
- The number of times each warp repeats its pattern.
|
||||
- The number of elements each thread can load in a single vector instruction.
|
||||
|
||||
---
|
||||
|
||||
## Bottom-Up Construction Approach
|
||||
|
||||
### Step 1: Determine K Dimension Layout
|
||||
|
||||
**Start with the innermost dimension (K) for memory coalescing:**
|
||||
|
||||
```cpp
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType); // Elements per thread (vector load)
|
||||
constexpr index_t K0 = kKPerBlock / K1; // Threads needed in K dimension
|
||||
```
|
||||
|
||||
**Example (with fp16):**
|
||||
- `K1 = 16 / 2 = 8` → Each thread loads 8 fp16 elements in a single vector instruction
|
||||
- `kKPerBlock = 32`
|
||||
- `K0 = 32 / 8 = 4` → We need 4 threads along K to cover the entire K dimension
|
||||
|
||||
**Visual:**
|
||||
```
|
||||
K dimension (32 elements):
|
||||
Thread 0: [0-7] Thread 1: [8-15] Thread 2: [16-23] Thread 3: [24-31]
|
||||
K1=8 K1=8 K1=8 K1=8
|
||||
├──────────────────────────────────────────────────────────────┤
|
||||
K0=4 threads
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 2: Determine M Dimension Layout
|
||||
|
||||
**Now partition the M dimension hierarchically:**
|
||||
|
||||
#### Level 1: Threads per Warp in M (M2)
|
||||
|
||||
```cpp
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
```
|
||||
|
||||
- Warp size = 64 threads
|
||||
- K dimension already uses `K0 = 4` threads per row
|
||||
- `M2 = 64 / 4 = 16` → Each warp can have 16 threads in M dimension
|
||||
|
||||
**Visual (Single Warp):**
|
||||
```
|
||||
K dimension (4 threads)
|
||||
┌─────┬─────┬─────┬─────┐
|
||||
0 │ T0 │ T1 │ T2 │ T3 │
|
||||
1 │ T4 │ T5 │ T6 │ T7 │
|
||||
2 │ T8 │ T9 │ T10 │ T11 │
|
||||
M 3 │ T12 │ T13 │ T14 │ T15 │ ← 16 rows
|
||||
...│ ... │ ... │ ... │ ... │ (M2=16)
|
||||
15 │ T60 │ T61 │ T62 │ T63 │
|
||||
└─────┴─────┴─────┴─────┘
|
||||
One Warp = 64 threads
|
||||
```
|
||||
|
||||
#### Level 2: Warps per Block (M1)
|
||||
|
||||
```cpp
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
```
|
||||
|
||||
- `kBlockSize = 256` threads per block
|
||||
- `M1 = 256 / 64 = 4` → We have 4 warps per block
|
||||
|
||||
**Visual (4 Warps):**
|
||||
```
|
||||
Warp 0 (rows 0-15)
|
||||
Warp 1 (rows 16-31)
|
||||
Warp 2 (rows 32-47)
|
||||
Warp 3 (rows 48-63)
|
||||
↑
|
||||
M1 = 4 warps cover 64 rows total
|
||||
```
|
||||
|
||||
#### Level 3: Repetitions (M0)
|
||||
|
||||
```cpp
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
```
|
||||
|
||||
- `kMPerBlock = 256` rows to cover
|
||||
- `M2 * M1 = 16 * 4 = 64` rows covered by all warps
|
||||
- `M0 = 256 / 64 = 4` → Each warp must repeat its pattern 4 times
|
||||
|
||||
**Visual (Complete Block):**
|
||||
```
|
||||
┌──────────────┐
|
||||
│ Iteration 0 │ ← Warp 0: rows 0-15, Warp 1: rows 16-31, ...
|
||||
│ (rows 0-63) │
|
||||
├──────────────┤
|
||||
│ Iteration 1 │ ← Warp 0: rows 64-79, Warp 1: rows 80-95, ...
|
||||
│ (rows 64-127)│
|
||||
├──────────────┤
|
||||
│ Iteration 2 │ ← Warp 0: rows 128-143, Warp 1: rows 144-159, ...
|
||||
│(rows 128-191)│
|
||||
├──────────────┤
|
||||
│ Iteration 3 │ ← Warp 0: rows 192-207, Warp 1: rows 208-223, ...
|
||||
│(rows 192-255)│
|
||||
└──────────────┘
|
||||
M0 = 4 iterations
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## The Tile Distribution Encoding
|
||||
|
||||
Now we can construct the distribution:
|
||||
|
||||
```cpp
|
||||
tile_distribution_encoding<
|
||||
sequence<1>, // [1] Replication
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // [2] Hierarchy
|
||||
tuple<sequence<1>, sequence<1, 2>>, // [3] Parallelism:
|
||||
tuple<sequence<1>, sequence<2, 0>>, // [3] Parallelism
|
||||
sequence<1, 2>, // [4] Yield
|
||||
sequence<0, 1> // [4] Yield
|
||||
>
|
||||
```
|
||||
|
||||
### [1] Replication: `sequence<1>`
|
||||
|
||||
Defines how many times warp patterns are replicated:
|
||||
- `1` = Each warp has a unique pattern (no replication)
|
||||
- `2` = Warp 0 and Warp 1 do the same thing, Warp 2 and Warp 3 do the same thing
|
||||
- `4` = All warps do the same thing
|
||||
|
||||
In our case: `1` means no replication (each warp is independent).
|
||||
|
||||
---
|
||||
|
||||
### [2] Hierarchy: The Multi-Level Structure
|
||||
|
||||
```cpp
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>
|
||||
└───────┬──────────┘ └──────┬────────┘
|
||||
M dimension K dimension
|
||||
```
|
||||
|
||||
**Concrete values:**
|
||||
- M hierarchy: `sequence<4, 4, 16>` = (4 repetitions, 4 warps, 16 threads/warp)
|
||||
- K hierarchy: `sequence<4, 8>` = (4 threads, 8 elements/thread)
|
||||
|
||||
---
|
||||
|
||||
### [3] Parallelism: Addressing the Hierarchy
|
||||
|
||||
**The key insight:** Read the tuples **vertically** to understand indexing!
|
||||
|
||||
```cpp
|
||||
tuple<sequence<1>, sequence<1, 2>>
|
||||
tuple<sequence<1>, sequence<2, 0>>
|
||||
```
|
||||
|
||||
#### Reading Pattern
|
||||
|
||||
**Column 1 (Dimension 0 = M):**
|
||||
```
|
||||
sequence<1> → Address hierarchy index 1,1 → M1 (warps/block in M dimension)
|
||||
sequence<1>
|
||||
```
|
||||
|
||||
**Column 2 (Dimension 1 = K):**
|
||||
```
|
||||
sequence<1, 2>
|
||||
sequence<2, 0>
|
||||
```
|
||||
[1,2] M2=threads/warp in M dimension
|
||||
[2,0] K0=threads/warp in K dimension
|
||||
|
||||
---
|
||||
|
||||
### [4] Yield Sequences: Output Ordering
|
||||
|
||||
```cpp
|
||||
sequence<1, 2>
|
||||
sequence<0, 1>
|
||||
|
||||
[1,0] means M0=repetitions/warp in M dimension
|
||||
[2,1] means K1=elements/thread in K dimension
|
||||
```
|
||||
---
|
||||
|
||||
## Complete Example: Thread 25 in Warp 0
|
||||
|
||||
Let's trace where **Thread 25** in **Warp 0** reads data:
|
||||
|
||||
### Thread Coordinates
|
||||
- Thread ID in warp: 25
|
||||
- Warp ID in block: 0
|
||||
|
||||
### Decompose Thread 25
|
||||
```
|
||||
Thread 25 in a 2D layout (M2=16, K0=4):
|
||||
Row index: 25 / 4 = 6
|
||||
Col index: 25 % 4 = 1
|
||||
```
|
||||
|
||||
### M Position (Row)
|
||||
```
|
||||
M0 iteration: 0 (first iteration)
|
||||
M1 warp: 0 (warp 0)
|
||||
M2 thread: 6 (6th row in warp)
|
||||
→ M position = 0*64 + 0*16 + 6 = 6
|
||||
```
|
||||
|
||||
### K Position (Column)
|
||||
```
|
||||
K0 thread: 1 (column group 1)
|
||||
K1 elements: 8 (will load 8 consecutive elements)
|
||||
→ K position = 1*8 + [0-7] = elements 8-15
|
||||
```
|
||||
|
||||
**Result:** Thread 25 in Warp 0 loads **row 6, columns 8-15** (8 elements).
|
||||
|
||||
---
|
||||
|
||||
## Why This Matters
|
||||
|
||||
### 1. **Memory Coalescing**
|
||||
- Consecutive threads access consecutive memory → efficient global memory access
|
||||
- K dimension uses K1=8 for vectorized loads
|
||||
|
||||
### 2. **Warp Efficiency**
|
||||
- All 64 threads in a warp are utilized
|
||||
- Natural 2D layout: 16 threads (M) × 4 threads (K) = 64 threads
|
||||
|
||||
### 3. **Scalability**
|
||||
- M0 repetitions allow handling larger tiles
|
||||
- Same pattern scales to different sizes
|
||||
|
||||
### 4. **Register Allocation**
|
||||
- Each thread knows exactly how many elements it will hold
|
||||
- Compiler can allocate registers optimally
|
||||
|
||||
---
|
||||
|
||||
## Summary Table
|
||||
|
||||
| Parameter | Value | Meaning |
|
||||
|-----------|-------|---------|
|
||||
| **K1** | 8 | Elements per thread (vector width) |
|
||||
| **K0** | 4 | Threads along K per row |
|
||||
| **M2** | 16 | Threads along M per warp |
|
||||
| **M1** | 4 | Warps per block |
|
||||
| **M0** | 4 | Repetitions of warp pattern |
|
||||
| **Total Threads** | 256 | M1 × warp_size = 4 × 64 |
|
||||
| **Total Elements** | 8192 | 256×32 (MPerBlock × KPerBlock) |
|
||||
| **Elements/Thread** | 32 | M0×K1 = 4×8 |
|
||||
|
||||
---
|
||||
|
||||
## Visualization: Complete Thread Block
|
||||
|
||||
```
|
||||
Block Tile: 256×32
|
||||
|
||||
K dimension (32 elements)
|
||||
├─────────────────────┤
|
||||
      ┌──────────────────────┐ ┐
    0 │        Warp 0        │ │
   16 │        Warp 1        │ │ Iteration 0
   32 │        Warp 2        │ │ (M0=0)
   48 │        Warp 3        │ ┘
      ├──────────────────────┤ ┐
   64 │        Warp 0        │ │
   80 │        Warp 1        │ │ Iteration 1
   96 │        Warp 2        │ │ (M0=1)
  112 │        Warp 3        │ ┘
      ├──────────────────────┤ ┐
  128 │        Warp 0        │ │
  144 │        Warp 1        │ │ Iteration 2
  160 │        Warp 2        │ │ (M0=2)
  176 │        Warp 3        │ ┘
      ├──────────────────────┤ ┐
  192 │        Warp 0        │ │
  208 │        Warp 1        │ │ Iteration 3
  224 │        Warp 2        │ │ (M0=3)
  240 │        Warp 3        │ ┘
      └──────────────────────┘
|
||||
|
||||
Each warp processes 16 rows × 32 cols = 512 elements
|
||||
Each iteration processes 64 rows × 32 cols = 2048 elements
|
||||
Total: 4 iterations × 2048 = 8192 elements ✓
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Takeaways
|
||||
|
||||
1. **Bottom-up construction**: Start from vector width (K1), build up through thread/warp/block hierarchy
|
||||
2. **Vertical reading**: The repeat and elements tuples are read column-wise to address hierarchy levels
|
||||
3. **Replication controls redundancy**: How many warps share the same pattern
|
||||
4. **Hierarchy encodes structure**: The multi-level sequence defines the complete mapping
|
||||
|
||||
This design enables CK to achieve maximum GPU performance through optimal thread-to-data mapping!
|
||||
|
||||
506
tutorial/ck_tile/gemm/01_naive_gemm/WALKTHROUGH.md
Normal file
506
tutorial/ck_tile/gemm/01_naive_gemm/WALKTHROUGH.md
Normal file
@@ -0,0 +1,506 @@
|
||||
# Practice GEMM: Step-by-Step Code Walkthrough
|
||||
|
||||
This document provides a detailed walkthrough of `practice_gemm.cpp`, explaining each step of implementing a GEMM (General Matrix Multiplication) kernel using the CK Tile API.
|
||||
|
||||
## Overview
|
||||
|
||||
We'll implement `C = A × B` where:
|
||||
- `A` is an `M × K` matrix
|
||||
- `B` is an `N × K` matrix (note: transposed layout)
|
||||
- `C` is an `M × N` matrix
|
||||
|
||||
The implementation uses a hierarchical tiling strategy with two levels:
|
||||
1. **Block Tiles**: Processed by thread blocks
|
||||
2. **Wave Tiles**: Processed by warps (wavefronts) within blocks
|
||||
|
||||
---
|
||||
|
||||
## Step 1: Define Data Types
|
||||
|
||||
```cpp
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using CDataType = float;
|
||||
using AccDataType = float;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We use `half_t` (FP16) for input matrices A and B.
|
||||
- We use `float` (FP32) for output matrix C and accumulation for numerical accuracy
|
||||
- In typical CK examples, this information is part of a `GemmConfig` struct, but here we define it directly for simplicity
|
||||
---
|
||||
|
||||
## Step 2: Define Problem Size
|
||||
|
||||
```cpp
|
||||
ck_tile::index_t M = 512;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 64;
|
||||
ck_tile::index_t verification = 1;
|
||||
|
||||
ck_tile::index_t stride_a = K;
|
||||
ck_tile::index_t stride_b = K;
|
||||
ck_tile::index_t stride_c = N;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- `M = 512`: Number of rows in A and C
|
||||
- `N = 256`: Number of columns in B and C
|
||||
- `K = 64`: Inner dimension (columns of A, rows of B)
|
||||
- Strides define memory layout (row-major for A and C, transposed for B)
|
||||
|
||||
**Memory Layout:**
|
||||
```
|
||||
Matrix A (M×K): Matrix B (N×K): Matrix C (M×N):
|
||||
[512 rows] [256 rows] [512 rows]
|
||||
[64 cols] [64 cols] [256 cols]
|
||||
stride = K stride = K stride = N
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 3: Create Host Tensors
|
||||
|
||||
```cpp
|
||||
auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
|
||||
auto a_strides = std::array<ck_tile::index_t, 2>{stride_a, 1};
|
||||
auto b_strides = std::array<ck_tile::index_t, 2>{stride_b, 1};
|
||||
auto c_strides = std::array<ck_tile::index_t, 2>{stride_c, 1};
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides);
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We create three tensors on the host (CPU) memory
|
||||
- Each tensor is defined by its shape (`lengths`) and memory layout (`strides`)
|
||||
- `HostTensor` is a CK Tile utility class that manages CPU memory
|
||||
|
||||
**Stride explanation:**
|
||||
- For A: `stride_a = K` means moving to the next row requires skipping K elements
|
||||
- For B: `stride_b = K` means B is stored in transposed format
|
||||
- For C: `stride_c = N` means row-major layout
|
||||
|
||||
---
|
||||
|
||||
## Step 4: Initialize Tensors with Random Data
|
||||
|
||||
```cpp
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
|
||||
c_host.SetZero();
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- A and B are filled with random values in the range [-5.0, 5.0]
|
||||
- C is initialized to zero (will store the output)
|
||||
|
||||
**Optional: Print Tensor Contents**
|
||||
```cpp
|
||||
// Commented out in the code, but available for debugging:
|
||||
// a_host.print_first_n(10); // Print first 10 elements of A
|
||||
```
|
||||
|
||||
The `print_first_n()` helper function can display tensor contents for debugging purposes.
|
||||
|
||||
---
|
||||
|
||||
## Step 5: Allocate Device Memory and Transfer Data
|
||||
|
||||
```cpp
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
ck_tile::DeviceMem b_device(b_host);
|
||||
ck_tile::DeviceMem c_device(c_host);
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- `DeviceMem` allocates GPU memory matching the size of host tensors
|
||||
- The constructor **automatically transfers data from host to device**
|
||||
- This is a convenience wrapper around `hipMalloc` and `hipMemcpy`
|
||||
|
||||
**Memory Flow:**
|
||||
```
|
||||
CPU (Host) GPU (Device)
|
||||
┌─────────┐ ┌─────────┐
|
||||
│ a_host │ ────────> │a_device │
|
||||
│ b_host │ ────────> │b_device │
|
||||
│ c_host │ ────────> │c_device │
|
||||
└─────────┘ └─────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 6: Configure Hierarchical Tiling
|
||||
|
||||
```cpp
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>;
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We define a two-level tiling hierarchy for the GEMM computation
|
||||
|
||||
### Block Tile (256 × 128 × 32)
|
||||
- **256**: M dimension per block (rows of A and C)
|
||||
- **128**: N dimension per block (columns of B and C)
|
||||
- **32**: K dimension per block (inner dimension)
|
||||
- Each block tile is processed by one **thread block** (256 threads)
|
||||
|
||||
### Wave Tile (16 × 16 × 16)
|
||||
- **16 × 16**: Output tile dimensions (M × N) per warp iteration
|
||||
- **16**: K dimension per warp iteration
|
||||
- Each wave tile is processed by one **warp** (64 threads on AMD GPUs)
|
||||
|
||||
**Important:** The WaveTile (16×16×16) is NOT the same as the MFMA instruction size (32×32×8). The WaveTile represents the work done per warp per iteration, while MFMA is the underlying hardware instruction. Multiple MFMA operations may be needed to compute one wave tile
|
||||
|
||||
**Important Note:**
In this example, the problem size (512 × 256 × 64) is **larger** than the block tile size (256 × 128 × 32), so 2 × 2 = **4 thread blocks** are needed to cover the output matrix C. The diagrams below show the region processed by a **single** thread block.
|
||||
|
||||
### Tiling Visualization:
|
||||
|
||||
#### Block Tile of A (256 × 32 in M × K):
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ One Block Tile (256 × 32) │
|
||||
│ ┌────┬────┐ │
|
||||
│ │16×│16× │ ← Wave tiles (16×16) │
|
||||
│ │ 16│ 16 │ in M×K space │
|
||||
│ ├────┼────┤ │
|
||||
│ │ │ │ │
|
||||
│ ├────┼────┤ │
|
||||
│ │ .. │ .. │ 16 tiles in M │
|
||||
│ ├────┼────┤ 2 tiles in K │
|
||||
│ │ │ │ │
|
||||
│ └────┴────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────┘
|
||||
```
|
||||
|
||||
#### Block Tile of B (128 × 32 in N × K):
|
||||
```
|
||||
┌──────────────────────────────┐
|
||||
│ One Block Tile (128 × 32) │
|
||||
│ ┌────┬────┐ │
|
||||
│ │16×│16× │ ← Wave tiles │
|
||||
│ │ 16│ 16 │ (16×16) │
|
||||
│ ├────┼────┤ │
|
||||
│ │ │ │ │
|
||||
│ ├────┼────┤ 8 tiles in N │
|
||||
│ │ .. │ .. │ 2 tiles in K │
|
||||
│ ├────┼────┤ │
|
||||
│ │ │ │ │
|
||||
│ └────┴────┘ │
|
||||
└──────────────────────────────┘
|
||||
```
|
||||
|
||||
#### Block Tile of C (256 × 128 in M × N) - Output:
|
||||
```
|
||||
┌─────────────────────────────────────────────────┐
|
||||
│ One Block Tile (256 × 128) │
|
||||
│ │
|
||||
│ ┌────┬────┬────┬────┬────┬────┬────┬────┐ │
|
||||
│ │16× │ │ │ │ │ │ │ │ │
|
||||
│ │ 16 │ │ │ │ │ │ │ │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ │ │ │ │ │ │ │ │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ │ │ │ │ │ │ │ │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ │ │ │ │ │ │ │ │ │
|
||||
│ └────┴────┴────┴────┴────┴────┴────┴────┘ │
|
||||
│ │
|
||||
│ 16 wave tiles in M direction │
|
||||
│ 8 wave tiles in N direction │
|
||||
│ Total: 128 wave tiles (16×16 each) │
|
||||
└─────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
#### How Wave Tiles Combine (C = A × B):
|
||||
```
|
||||
Matrix A Matrix B (stored transposed N×K) Matrix C
|
||||
(256×32) (128×32) (256×128)
|
||||
|
||||
Row of A tiles: Row of B tiles: One wave tile in C:
|
||||
┌────┬────┐ ┌────┬────┐ ┌────┐
|
||||
│ A₀ │ A₁ │ × │ B₀ │ B₁ │ = │ C │ (16×16)
|
||||
└────┴────┘ └────┴────┘ └────┘
|
||||
16×16 each 16×16 each
|
||||
|
||||
Computation: C = A₀×B₀ᵀ + A₁×B₁ᵀ
|
||||
↑ ↑
|
||||
K=0..15 K=16..31
|
||||
|
||||
Each wave tile in C is computed by:
|
||||
- Taking one row of wave tiles from A (2 tiles along K)
|
||||
- Taking one row of wave tiles from B (2 tiles along K)
|
||||
Note: B is stored transposed (N×K), so a "row" in storage corresponds
|
||||
to a "column" in the logical B^T matrix used in computation
|
||||
- Performing dot product: Σ(A_k × B_k^T) for k=0,1
|
||||
```
|
||||
|
||||
**Key Insight:**
|
||||
- Each **wave tile in C** (16×16) requires a **dot product** of 2 wave tiles from A and 2 wave tiles from B
|
||||
- Since B is stored transposed (N×K layout), we access **rows** of B tiles in memory
|
||||
- This is the fundamental operation repeated across all 128 wave tiles in C
|
||||
- Each warp computes one wave tile using MFMA instructions
|
||||
|
||||
---
|
||||
|
||||
## Step 7: Create Shape, Problem, and Policy Structs
|
||||
|
||||
```cpp
|
||||
using PracticeGemmShape = ck_tile::PracticeGemmShape<BlockTile, WaveTile>;
|
||||
std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl;
|
||||
|
||||
using PracticeGemmHostProblem = ck_tile::
|
||||
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
|
||||
|
||||
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### 1. **Shape Struct**
|
||||
Encapsulates all tile shape information (BlockTile and WaveTile dimensions).
|
||||
|
||||
### 2. **Problem Struct**
|
||||
Holds complete problem description:
|
||||
- Data types (ADataType, BDataType, CDataType, AccDataType)
|
||||
- Shape information (BlockTile, WaveTile)
|
||||
|
||||
In more complex examples, this would also include:
|
||||
- Data layouts (row-major, column-major)
|
||||
- Mathematical operations (e.g., transposed GEMM)
|
||||
|
||||
### 3. **Policy Struct**
|
||||
Describes data movement and thread-to-data mapping:
|
||||
- Currently contains `MakeBlock2TileMap()`: Maps thread block IDs to tile positions
|
||||
- In more complex kernels, includes:
|
||||
- DRAM access patterns
|
||||
- LDS (Local Data Share) usage strategies
|
||||
- Thread distribution within blocks
|
||||
|
||||
**CK Tile Design Pattern:**
|
||||
```
|
||||
Kernel = Problem + Policy + Epilogue
|
||||
↑ ↑ ↑
|
||||
(What) (How) (Post-processing)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 8: Calculate Grid and Block Dimensions
|
||||
|
||||
```cpp
|
||||
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
|
||||
ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
|
||||
|
||||
std::cout << "kGridSize: " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
constexpr ck_tile::index_t kBlockPerCU = 1;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### Grid Size Calculation
|
||||
```cpp
|
||||
kGridSize = ceil(M / BlockTile_M) × ceil(N / BlockTile_N)
|
||||
= ceil(512 / 256) × ceil(256 / 128)
|
||||
= 2 × 2
|
||||
= 4 thread blocks
|
||||
```
|
||||
|
||||
Our problem requires **4 thread blocks** to cover the entire output matrix C (2 blocks in M direction, 2 blocks in N direction).
|
||||
|
||||
### Block Configuration
|
||||
- `kBlockSize = 256`: Each thread block has 256 threads
|
||||
- 256 threads / 64 threads per warp = **4 warps per block**
|
||||
- `kBlockPerCU = 1`: Launch 1 block per Compute Unit (for simplicity)
|
||||
|
||||
**Thread Hierarchy:**
|
||||
```
|
||||
GPU
└── 4 Thread Blocks (Grid)
    └── 256 Threads each
        ├── Warp 0 (threads 0-63)
        ├── Warp 1 (threads 64-127)
        ├── Warp 2 (threads 128-191)
        └── Warp 3 (threads 192-255)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 9: Create and Launch the Kernel
|
||||
|
||||
```cpp
|
||||
using gemm_kernel =
|
||||
ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 0, 1},
|
||||
ck_tile::make_kernel<kBlockPerCU>(gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_device.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c));
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### 1. Kernel Composition
|
||||
```cpp
|
||||
using gemm_kernel = ck_tile::PracticeGemmKernel<Problem, Policy>;
|
||||
```
|
||||
The kernel is composed from Problem and Policy structs, following the CK Tile design pattern.
|
||||
|
||||
### 2. Kernel Launch
|
||||
`launch_kernel()` is a CK Tile utility that:
|
||||
- Launches the GPU kernel using HIP runtime
|
||||
- Measures execution time
|
||||
- Returns average execution time in milliseconds
|
||||
|
||||
### 3. Launch Parameters
|
||||
- **Stream config**: `{nullptr, true, 0, 0, 1}` - default stream, timing enabled
|
||||
- **Grid size**: `kGridSize = 4` - number of thread blocks (as computed in Step 8)
|
||||
- **Block size**: `kBlockSize = 256` - threads per block
|
||||
- **Shared memory**: `0` - no dynamic shared memory in this example
|
||||
- **Kernel arguments**: Device pointers and problem dimensions
|
||||
|
||||
### 4. Kernel Execution Flow
|
||||
```
|
||||
launch_kernel() calls gemm_kernel.operator()()
|
||||
↓
|
||||
PracticeGemmKernel::operator()
|
||||
↓
|
||||
Creates tensor views over device memory
|
||||
↓
|
||||
Calls block-level pipeline
|
||||
↓
|
||||
Block pipeline calls warp-level pipeline
|
||||
↓
|
||||
Warp pipeline calls MFMA instructions
|
||||
↓
|
||||
Results written back to C matrix
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 10: Verify Results
|
||||
|
||||
```cpp
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// Reference gemm on CPU
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
|
||||
// Copy GPU results back to host
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
c_device.FromDevice(c_host_dev.mData.data());
|
||||
|
||||
// Compare results
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### 1. CPU Reference Implementation
|
||||
```cpp
|
||||
reference_basic_gemm<...>(a_host, b_host, c_host_ref);
|
||||
```
|
||||
Computes GEMM on CPU using a simple nested loop implementation (ground truth).
|
||||
|
||||
### 2. Copy GPU Results to Host
|
||||
```cpp
|
||||
c_device.FromDevice(c_host_dev.mData.data());
|
||||
```
|
||||
Transfers the computed result from GPU memory back to CPU for comparison.
|
||||
|
||||
### 3. Error Checking
|
||||
```cpp
|
||||
ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
|
||||
```
|
||||
Compares GPU and CPU results element-wise with tolerance:
|
||||
- **Relative error**: 1e-3 (0.1%)
|
||||
- **Absolute error**: 1e-3
|
||||
|
||||
**Verification Flow:**
|
||||
```
|
||||
CPU GPU
|
||||
┌─────────┐ ┌─────────┐
|
||||
│ a_host │ ────────> │a_device │
|
||||
│ b_host │ ────────> │b_device │
|
||||
└─────────┘ └─────────┘
|
||||
│ │
|
||||
↓ ↓
|
||||
reference_gemm() GPU kernel
|
||||
│ │
|
||||
↓ ↓
|
||||
┌──────────┐ ┌──────────┐
|
||||
│c_host_ref│ │c_device │
|
||||
└──────────┘ └──────────┘
|
||||
│ │
|
||||
│ ↓
|
||||
│ FromDevice()
|
||||
│ │
|
||||
↓ ↓
|
||||
└────> check_err() <───┘
|
||||
│
|
||||
↓
|
||||
Pass/Fail
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Complete Execution Flow Summary
|
||||
|
||||
```
|
||||
1. Define data types (FP16 inputs, FP32 output)
|
||||
↓
|
||||
2. Set problem size (M=256, N=128, K=32)
|
||||
↓
|
||||
3. Create host tensors and initialize with random data
|
||||
↓
|
||||
4. Allocate device memory and transfer data (CPU → GPU)
|
||||
↓
|
||||
5. Configure hierarchical tiling (BlockTile, WaveTile)
|
||||
↓
|
||||
6. Create Shape, Problem, and Policy structs
|
||||
↓
|
||||
7. Calculate grid/block dimensions (1 block, 256 threads)
|
||||
↓
|
||||
8. Compose and launch kernel (Problem + Policy)
|
||||
↓
|
||||
9. Execute GEMM on GPU
|
||||
│ ├─ Block-level pipeline
|
||||
│ ├─ Warp-level pipeline
|
||||
│ └─ MFMA instructions
|
||||
↓
|
||||
10. Verify results (compare GPU vs CPU reference)
|
||||
↓
|
||||
11. Calculate and print performance metrics
|
||||
↓
|
||||
12. Return success/failure
|
||||
```
|
||||
|
||||
---
|
||||
@@ -0,0 +1,166 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_policy.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
// Block-level GEMM pipeline.
//   A Tile Window       : global memory (DRAM)
//   B Tile Window       : global memory (DRAM)
//   C Distributed tensor: registers
//
// For each K-loop iteration: load an A/B block tile from DRAM into registers,
// store it into LDS, then run the block GEMM on the LDS data, accumulating
// into the C register tile. No prefetching/double-buffering is done here.
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegPolicy>
struct BlockGemmPipelineAGmemBGmemCReg
{
    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using CDataType = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // Block-tile extents along M, N, K.
    static constexpr index_t kMPerBlock = BlockGemmShape::kM;
    static constexpr index_t kNPerBlock = BlockGemmShape::kN;
    static constexpr index_t kKPerBlock = BlockGemmShape::kK;

    // Block GEMM implementation selected by the policy (reads A/B from LDS).
    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    // Total LDS bytes needed: the A block (rounded up to a 16-byte boundary so
    // the B block starts aligned) followed by the B block.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
    {
        return integer_divide_ceil(
                   sizeof(ADataType) *
                       Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
                   16) *
                   16 +
               sizeof(BDataType) *
                   Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
    }

    // Run the pipeline for `num_loop` K-steps and return the accumulator tile.
    //   a_dram_block_window_tmp: [kMPerBlock, kKPerBlock] window into A (global)
    //   b_dram_block_window_tmp: [kNPerBlock, kKPerBlock] window into B (global)
    //   num_loop               : number of K iterations (K / kKPerBlock)
    //   p_smem                 : raw LDS buffer of at least GetStaticLdsSize() bytes
    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                        const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                        index_t num_loop,
                                        void* p_smem) const
    {
        // The windows' element types must match the problem's A/B types.
        static_assert(
            std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
            "wrong!");

        // The windows' compile-time extents must match the block-tile shape.
        static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        // -----------------------------------------------------------------------------------------
        // Definitions of all needed tiles

        // A tile in LDS (placed at the start of the shared-memory buffer)
        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);

        constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();

        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        // Byte size of the A LDS region, rounded up to 16 bytes so B starts aligned.
        constexpr index_t a_lds_block_space_size_aligned =
            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
            16;

        // B tile in LDS (placed right after the aligned A region)
        BDataType* p_b_lds = static_cast<BDataType*>(
            static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));

        constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();

        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

        // A DRAM tile window for load; the policy's distribution maps threads to elements.
        auto a_copy_dram_window =
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeADramTileDistribution<Problem>());

        // A LDS tile window for store (same distribution, so each thread writes
        // back exactly the elements it loaded).
        auto a_copy_lds_window =
            make_tile_window(a_lds_block,
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             {0, 0},
                             a_copy_dram_window.get_tile_distribution());

        // B DRAM tile window for load
        auto b_copy_dram_window =
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeBDramTileDistribution<Problem>());

        // B LDS tile window for store
        auto b_copy_lds_window =
            make_tile_window(b_lds_block,
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             {0, 0},
                             b_copy_dram_window.get_tile_distribution());

        // A LDS tile for block GEMM (read-side window, no explicit distribution)
        auto a_lds_gemm_window = make_tile_window(
            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // B LDS tile for block GEMM
        auto b_lds_gemm_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // Block GEMM
        auto block_gemm = BlockGemm();

        // Acc register tile; its type is whatever the block GEMM's two-argument
        // overload would return for these windows.
        auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};

        using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
        using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

        // Register staging tiles used for the DRAM -> LDS copy.
        using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
        using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

        ABlockTile a_block_tile;
        BBlockTile b_block_tile;
        // Per-iteration window advance: move along K by one block-tile.
        using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
        using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
        constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
        constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);

        // -------------------------------------------------------------------------------------
        // Gemm pipeline start

        // Initialize C
        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

        // non-prefetch: each iteration fully loads, stores to LDS, then computes.
        index_t iCounter = num_loop;

        while(iCounter > 0)
        {
            // Global -> register
            a_block_tile = load_tile(a_copy_dram_window);
            b_block_tile = load_tile(b_copy_dram_window);
            // Advance A/B windows to the next K slice for the following iteration.
            move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
            move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
            // Register -> LDS
            store_tile(a_copy_lds_window, a_block_tile);
            store_tile(b_copy_lds_window, b_block_tile);

            // Make the LDS stores visible to all threads before computing.
            block_sync_lds();
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            // Ensure all reads of LDS finished before the next iteration overwrites it.
            block_sync_lds();

            iCounter--;
        }

        return c_block_tile;
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../warp_level/block_gemm_asmem_bsmem_creg.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmPipelineAGmemBGmemCReg
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmPipelineAGmemBGmemCRegPolicy
{
    // A tile layout in LDS: 3d + no padding (NAIVE_IMPLEMENTATION).
    // K is split into (K / kKPack, kKPack) so the innermost kKPack elements are
    // contiguous, then merged back into a 2-D [M, K] descriptor.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
    {
        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
        // Innermost contiguous run along K. NOTE(review): hard-coded 8; for a
        // 2-byte element type this is a 16-byte vector -- confirm for other dtypes.
        constexpr index_t kKPack = 8;

        // Row-major [M, K/kKPack, kKPack] with unit stride on the kKPack dim.
        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
            make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
            number<kKPack>{},
            number<1>{});

        // Merge (K/kKPack, kKPack) back into a single K dimension -> [M, K] view.
        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
            a_lds_block_desc_0,
            make_tuple(make_pass_through_transform(kMPerBlock),
                       make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
            make_tuple(sequence<0>{}, sequence<1, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return a_lds_block_desc;
    }

    // B tile layout in LDS: 3d + no padding (NAIVE_IMPLEMENTATION).
    // Mirrors MakeALdsBlockDescriptor with N in place of M.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
    {
        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
        // See note on kKPack in MakeALdsBlockDescriptor.
        constexpr index_t kKPack = 8;

        constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
            make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
            number<kKPack>{},
            number<1>{});

        constexpr auto b_lds_block_desc = transform_tensor_descriptor(
            b_lds_block_desc_0,
            make_tuple(make_pass_through_transform(kNPerBlock),
                       make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
            make_tuple(sequence<0>{}, sequence<1, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return b_lds_block_desc;
    }

    // Thread-to-element mapping for loading the A block tile from global memory.
    // The M dimension is decomposed as (M0, M1, M2) and K as (K0, K1):
    //   K1: elements per thread along K, sized for a 16-byte vector load
    //   K0: remaining K splits, covered by lanes within a warp
    //   M2: the rest of the warp's lanes spread over M
    //   M1: one slice per warp in the block
    //   M0: per-thread repeats over M
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;

        constexpr index_t kBlockSize = Problem::kBlockSize;

        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        constexpr index_t K1 = 16 / sizeof(ADataType);
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t M2 = get_warp_size() / K0;
        // coalesce reading for each blocks
        constexpr index_t M1 = kBlockSize / get_warp_size();
        constexpr index_t M0 = kMPerBlock / (M2 * M1);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    // Thread-to-element mapping for loading the B block tile from global memory.
    // Identical scheme to MakeADramTileDistribution with N in place of M.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;

        constexpr index_t kBlockSize = Problem::kBlockSize;

        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        constexpr index_t K1 = 16 / sizeof(BDataType);
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;
        // coalesce reading for each blocks
        constexpr index_t N1 = kBlockSize / get_warp_size();
        constexpr index_t N0 = kNPerBlock / (N2 * N1);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    // Block GEMM implementation used by the pipeline (A/B in LDS, C in registers).
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        return BlockGemmASmemBSmemCReg<Problem>{};
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
72
tutorial/ck_tile/gemm/01_naive_gemm/host_level/grid_gemm.hpp
Normal file
72
tutorial/ck_tile/gemm/01_naive_gemm/host_level/grid_gemm.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Grid-level GEMM: maps each thread block to one [kMPerBlock, kNPerBlock] tile
// of C, runs the block pipeline over the whole K dimension, applies the
// element-wise function, and stores the result.
template <typename Problem, typename Policy>
struct GridGemm
{
    using ADataType = typename Problem::ADataType;
    using BDataType = typename Problem::BDataType;
    using CDataType = typename Problem::CDataType;
    using AccDataType = typename Problem::AccDataType;
    using CElementFunction = typename Problem::CElementFunction;

    // Block-tile extents supplied by the policy.
    static constexpr auto kMPerBlock = Policy::kMPerBlock;
    static constexpr auto kNPerBlock = Policy::kNPerBlock;
    static constexpr auto kKPerBlock = Policy::kKPerBlock;

    // Compute this block's C tile.
    //   a_grid: [M, K] view of A in global memory
    //   b_grid: [N, K] view of B in global memory
    //   c_grid: [M, N] view of C in global memory (written)
    //   c_element_func: applied to each output element after type conversion
    template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
    CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid,
                                   const BGridTensorView& b_grid,
                                   CGridTensorView& c_grid,
                                   const CElementFunction& c_element_func) const
    {
        // Problem sizes read back from the tensor views.
        const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{});
        const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{});
        const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{});

        // divide problem: pick this block's (m, n) tile via the policy's map.
        const auto id_block = get_block_id();

        const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
        const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);

        const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);

        const auto id_tile = block2tile(id_block);

        // readfirstlane: the tile origin is uniform across the wave, so keep it
        // in a scalar register.
        const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template at<0>() * kMPerBlock);
        const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template at<1>() * kNPerBlock);

        // A block window
        auto a_block_window = make_tile_window(
            a_grid, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {iM, 0});

        // B block window
        auto b_block_window = make_tile_window(
            b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});

        constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();

        // Statically-sized LDS buffer shared by the A and B tiles.
        __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];

        // Run the pipeline over K / kKPerBlock iterations.
        // NOTE(review): K / kKPerBlock truncates -- assumes K is divisible by
        // kKPerBlock.
        const auto acc_block_tile =
            block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char);

        // cast to CDataType and apply CElementFunction
        const auto c_block_tile = tile_elementwise_in(
            [&](const auto& acc) { return c_element_func(type_convert<CDataType>(acc)); },
            acc_block_tile);

        // store C
        auto c_window = make_tile_window(
            c_grid, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, iN});

        store_tile(c_window, c_block_tile);
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
155
tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.cpp
Normal file
155
tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.cpp
Normal file
@@ -0,0 +1,155 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "practice_gemm.hpp"
|
||||
#include "../reference_gemm.hpp"
|
||||
|
||||
/*
|
||||
* Naive GEMM implementation (no optimizations)
|
||||
* A [M, K]
|
||||
* B [N, K]
|
||||
* C [M, N]
|
||||
*/
|
||||
|
||||
// elementwise lambda
|
||||
struct CElementFunction
|
||||
{
|
||||
template <typename X>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const X& x) const
|
||||
{
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t M = 3328;
|
||||
ck_tile::index_t N = 4096;
|
||||
ck_tile::index_t K = 4096;
|
||||
|
||||
if(argc == 2)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
}
|
||||
if(argc == 5)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
M = std::stoi(argv[2]);
|
||||
N = std::stoi(argv[3]);
|
||||
K = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
printf("*** Naive implementation test ***\n");
|
||||
|
||||
const ck_tile::index_t Lda = K;
|
||||
const ck_tile::index_t Ldb = K;
|
||||
const ck_tile::index_t Ldc = N;
|
||||
|
||||
const auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
const auto a_strides = std::array<ck_tile::index_t, 2>{Lda, 1};
|
||||
|
||||
const auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
const auto b_strides = std::array<ck_tile::index_t, 2>{Ldb, 1};
|
||||
|
||||
const auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
const auto c_strides = std::array<ck_tile::index_t, 2>{Ldc, 1};
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.mData.data());
|
||||
b_buf.ToDevice(b_host.mData.data());
|
||||
|
||||
// Alignment
|
||||
constexpr ck_tile::index_t kAAlignment = 8;
|
||||
constexpr ck_tile::index_t kBAlignment = 8;
|
||||
constexpr ck_tile::index_t kCAlignment = 8;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 256;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 32;
|
||||
constexpr ck_tile::index_t kGemmNPerBlock = 128;
|
||||
|
||||
ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
|
||||
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kWarpSize = 64; // AMD GPU warp size
|
||||
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / kWarpSize;
|
||||
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
|
||||
using gemm_kernel = ck_tile::Gemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CElementFunction,
|
||||
kAAlignment,
|
||||
kBAlignment,
|
||||
kCAlignment,
|
||||
kBlockSize,
|
||||
kGemmMPerBlock,
|
||||
kGemmNPerBlock,
|
||||
kGemmKPerBlock>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 5, 1000},
|
||||
ck_tile::make_kernel<kBlockPerCu>(gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
Lda,
|
||||
Ldb,
|
||||
Ldc,
|
||||
CElementFunction{}));
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// reference gemm
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, ADataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
c_buf.FromDevice(c_host_dev.mData.data());
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
139
tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp
Normal file
139
tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#include "block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "host_level/grid_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// "What to compute" for GridGemm: the element types of A, B, C, the
// accumulator type, and the element-wise epilogue functor applied to C.
template <typename ADataType_,
          typename BDataType_,
          typename AccDataType_,
          typename CDataType_,
          typename CElementFunction_>
struct GridGemmProblem
{
    using ADataType = ADataType_;
    using BDataType = BDataType_;
    using AccDataType = AccDataType_;
    using CDataType = CDataType_;

    using CElementFunction = CElementFunction_;
};
|
||||
|
||||
// Compile-time tile extents (M, N, K) used as the block-tile shape.
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
struct TileGemmShape
{
    static constexpr index_t kM = kMPerTile;
    static constexpr index_t kN = kNPerTile;
    static constexpr index_t kK = kKPerTile;
};
|
||||
|
||||
// "What to compute" for the block-level pipeline: A/B input types, the
// pipeline's output type (CDataType here is the accumulator the pipeline
// produces), the thread-block size, and the block-tile shape.
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          index_t kBlockSize_,
          typename BlockGemmShape_>
struct BlockGemmPipelineProblem
{
    using ADataType = remove_cvref_t<ADataType_>;
    using BDataType = remove_cvref_t<BDataType_>;
    using CDataType = remove_cvref_t<CDataType_>;
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    // Number of threads per block.
    static constexpr index_t kBlockSize = kBlockSize_;
};
|
||||
|
||||
// C = A * B
|
||||
// C = A * B
// Top-level kernel functor: wires together the grid-level problem/policy,
// builds global-memory tensor views from raw pointers, and invokes GridGemm.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename CElementFunction,
          index_t kAAlignment,
          index_t kBAlignment,
          index_t kCAlignment,
          index_t kBlockSize_,
          index_t kMPerBlock_,
          index_t kNPerBlock_,
          index_t kKPerBlock_>
struct Gemm
{
    static constexpr index_t kBlockSize = kBlockSize_;

    using GridGemmProblem_ =
        GridGemmProblem<ADataType, BDataType, AccDataType, CDataType, CElementFunction>;

    // "How to compute": block-tile sizes, block-to-tile mapping, and which
    // block pipeline to use.
    struct GridGemmPolicy
    {
        static constexpr index_t kBlockSize = kBlockSize_;
        static constexpr index_t kMPerBlock = kMPerBlock_;
        static constexpr index_t kNPerBlock = kNPerBlock_;
        static constexpr index_t kKPerBlock = kKPerBlock_;

        // Returns a callable mapping a linear block id to an (m, n) tile index.
        // Built from a merge transform over (N0, M0); calculate_lower_index
        // splits the linear id into the two tile coordinates, which are then
        // swapped into (m, n) order.
        template <typename Problem>
        CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
        {
            const auto unmerge = make_merge_transform(make_tuple(N0, M0));

            return [unmerge](index_t block_id) {
                multi_index<2> unmerged;
                unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));

                return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
            };
        }

        // Constructs the block-level pipeline. Note: AccDataType is passed as
        // the pipeline's CDataType on purpose -- the pipeline returns the
        // accumulator tile; GridGemm converts it to CDataType when storing.
        template <typename Problem>
        CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
        {
            using BlockGemmPipelineProblem_ =
                BlockGemmPipelineProblem<ADataType,
                                         BDataType,
                                         AccDataType,
                                         kBlockSize,
                                         TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
            return BlockGemmPipelineAGmemBGmemCReg<BlockGemmPipelineProblem_>{};
        }
    };

    using GridGemm_ = GridGemm<GridGemmProblem_, GridGemmPolicy>;

    // Kernel entry point.
    //   p_a: A [M, K], leading dimension Lda
    //   p_b: B [N, K], leading dimension Ldb (row-major B^T layout)
    //   p_c: C [M, N], leading dimension Ldc (written)
    CK_TILE_DEVICE void operator()(const ADataType* p_a,
                                   const BDataType* p_b,
                                   CDataType* p_c,
                                   const index_t M,
                                   const index_t N,
                                   const index_t K,
                                   const index_t Lda,
                                   const index_t Ldb,
                                   const index_t Ldc,
                                   const CElementFunction& c_element_func) const
    {
        // Global-memory views; the alignment parameters promise vectorizable
        // access to the underlying buffers.
        const auto a_dram = [&] {
            return make_naive_tensor_view<address_space_enum::global>(
                p_a, make_tuple(M, K), make_tuple(Lda, 1), number<kAAlignment>{}, number<1>{});
        }();

        const auto b_dram = [&] {
            return make_naive_tensor_view<address_space_enum::global>(
                p_b, make_tuple(N, K), make_tuple(Ldb, 1), number<kBAlignment>{}, number<1>{});
        }();

        const auto c_dram = [&] {
            return make_naive_tensor_view<address_space_enum::global>(
                p_c, make_tuple(M, N), make_tuple(Ldc, 1), number<kCAlignment>{}, number<1>{});
        }();

        GridGemm_{}(a_dram, b_dram, c_dram, c_element_func);
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,285 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "block_gemm_asmem_bsmem_creg_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegPolicy>
|
||||
struct BlockGemmASmemBSmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<
|
||||
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
|
||||
static constexpr index_t MWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
|
||||
static constexpr index_t NWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// Warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
|
||||
|
||||
// Construct C-Block-Tensor
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
// Hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// Warp GEMM
|
||||
if constexpr(KIterPerWarp == 0)
|
||||
{
|
||||
// c = a * b
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
// c += a * b
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmASmemBSmemCReg
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct BlockGemmASmemBSmemCRegPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
// NAIVE_IMPLEMENTATION uses 4x1 warp configuration
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
|
||||
// NAIVE_IMPLEMENTATION uses mfma m32 n32 k8
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
17
tutorial/ck_tile/gemm/02_padding_k_first/CMakeLists.txt
Normal file
17
tutorial/ck_tile/gemm/02_padding_k_first/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# Build config for the 02_padding_k_first GEMM tutorial example.
set(EXAMPLE_PADDING_K_FIRST "tile_tutorial_padding_k_first")

message(DEBUG "adding example ${EXAMPLE_PADDING_K_FIRST}")

# EXCLUDE_FROM_ALL: built only via the `tutorials` umbrella target below.
add_executable(${EXAMPLE_PADDING_K_FIRST} EXCLUDE_FROM_ALL gemm.cpp)
# Headers (kernel/pipeline/policy .hpp files) live next to this CMakeLists.
target_include_directories(${EXAMPLE_PADDING_K_FIRST} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_PADDING_K_FIRST_COMPILE_OPTIONS)

# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_PADDING_K_FIRST_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported)

target_compile_options(${EXAMPLE_PADDING_K_FIRST} PRIVATE ${EXAMPLE_PADDING_K_FIRST_COMPILE_OPTIONS})

add_dependencies(tutorials ${EXAMPLE_PADDING_K_FIRST})
|
||||
@@ -0,0 +1,285 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "block_gemm_asmem_bsmem_creg_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegPolicy>
|
||||
struct BlockGemmASmemBSmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<
|
||||
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
|
||||
static constexpr index_t MWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
|
||||
static constexpr index_t NWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// Warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
|
||||
|
||||
// Construct C-Block-Tensor
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
// Hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// Warp GEMM
|
||||
if constexpr(KIterPerWarp == 0)
|
||||
{
|
||||
// c = a * b
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
// c += a * b
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Policy for BlockGemmASmemBSmemCReg with MFMA_32x32x16 (8x2) instruction
|
||||
struct BlockGemmASmemBSmemCRegPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
// KERNEL_A uses 4x1 warp configuration
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
|
||||
// KERNEL_A uses mfma m32 n32 k16 (8x2 variant)
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,166 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_policy.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegPolicy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return integer_divide_ceil(
|
||||
sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16) *
|
||||
16 +
|
||||
sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Definitions of all needed tiles
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
a_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// non-prefetch
|
||||
index_t iCounter = num_loop;
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,129 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "block_gemm_asmem_bsmem_creg.hpp"

#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

namespace ck_tile {

/// Policy for BlockGemmPipelineAGmemBGmemCReg implementing the PADDING_K_FIRST
/// LDS-layout optimization: the A tile's rows are stored with one extra
/// kKPack-wide column of padding so consecutive rows start at different LDS
/// offsets (intended to reduce bank conflicts). B keeps the unpadded layout.
struct BlockGemmPipelineAGmemBGmemCRegPolicy
{
    /// LDS descriptor for the A block tile.
    ///
    /// The tile is first described as a 3-D (M, K/KPack, KPack) layout whose
    /// row stride is (K/KPack + 1) * KPack — i.e. PADDING_K_FIRST — and then
    /// the (K/KPack, KPack) pair is merged back into a logical (M, K) view.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
    {
        constexpr index_t M     = Problem::BlockGemmShape::kM;
        constexpr index_t K     = Problem::BlockGemmShape::kK;
        constexpr index_t KPack = 8;

        // Raw 3-D layout with the padded row stride.
        constexpr auto padded_desc = make_naive_tensor_descriptor(
            make_tuple(number<M>{}, number<K / KPack>{}, number<KPack>{}),
            make_tuple(number<(K / KPack + 1) * KPack>{}, number<KPack>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        // Merge (K/KPack, KPack) back into a single logical K dimension.
        return transform_tensor_descriptor(
            padded_desc,
            make_tuple(make_pass_through_transform(M),
                       make_merge_transform(make_tuple(K / KPack, KPack))),
            make_tuple(sequence<0>{}, sequence<1, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
    }

    /// LDS descriptor for the B block tile: plain (N, K) layout without any
    /// padding — this policy variant only pads A.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
    {
        constexpr index_t N     = Problem::BlockGemmShape::kN;
        constexpr index_t K     = Problem::BlockGemmShape::kK;
        constexpr index_t KPack = 8;

        // Unpadded 3-D layout (row stride is exactly K).
        constexpr auto raw_desc = make_naive_tensor_descriptor(
            make_tuple(number<N>{}, number<K / KPack>{}, number<KPack>{}),
            make_tuple(number<K>{}, number<KPack>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        // Merge (K/KPack, KPack) back into a single logical K dimension.
        return transform_tensor_descriptor(
            raw_desc,
            make_tuple(make_pass_through_transform(N),
                       make_merge_transform(make_tuple(K / KPack, KPack))),
            make_tuple(sequence<0>{}, sequence<1, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
    }

    /// Thread-to-data mapping for loading the A block tile from global memory.
    /// K is split into (K0, KVec) where KVec is sized for 16-byte vector
    /// accesses; M is split into (M0, MWarps, MLanes) over the block's warps
    /// and lanes so neighbouring lanes read neighbouring K chunks.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;

        constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t M         = Problem::BlockGemmShape::kM;
        constexpr index_t K         = Problem::BlockGemmShape::kK;

        constexpr index_t KVec   = 16 / sizeof(ADataType); // elements per 16-byte access
        constexpr index_t K0     = K / KVec;
        constexpr index_t MLanes = get_warp_size() / K0;
        constexpr index_t MWarps = BlockSize / get_warp_size();
        constexpr index_t M0     = M / (MLanes * MWarps);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M0, MWarps, MLanes>, sequence<K0, KVec>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    /// Thread-to-data mapping for loading the B block tile from global memory.
    /// Mirrors MakeADramTileDistribution with N in place of M.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;

        constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t N         = Problem::BlockGemmShape::kN;
        constexpr index_t K         = Problem::BlockGemmShape::kK;

        constexpr index_t KVec   = 16 / sizeof(BDataType); // elements per 16-byte access
        constexpr index_t K0     = K / KVec;
        constexpr index_t NLanes = get_warp_size() / K0;
        constexpr index_t NWarps = BlockSize / get_warp_size();
        constexpr index_t N0     = N / (NLanes * NWarps);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, NWarps, NLanes>, sequence<K0, KVec>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    /// Block-level GEMM implementation used by this pipeline.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        return BlockGemmASmemBSmemCReg<Problem>{};
    }
};

} // namespace ck_tile
|
||||
158
tutorial/ck_tile/gemm/02_padding_k_first/gemm.cpp
Normal file
158
tutorial/ck_tile/gemm/02_padding_k_first/gemm.cpp
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <cstring>
#include <iostream>

#include "ck_tile/host.hpp"
#include "gemm.hpp"
#include "../reference_gemm.hpp"

/*
 * KERNEL_A: GEMM with PADDING_K_FIRST + MFMA_32x32x16 (8x2)
 * A [M, K]
 * B [N, K]
 * C [M, N]
 */

// Identity elementwise epilogue applied to each C value before the store.
struct CElementFunction
{
    template <typename X>
    CK_TILE_HOST_DEVICE auto operator()(const X& x) const
    {
        return x;
    }
};

// Usage: exe [verify] | exe verify M N K
//   verify != 0 runs a host reference GEMM and validates the device result.
// Returns 0 on success (or when verification is skipped), non-zero on mismatch.
int main(int argc, char* argv[])
{
    using ADataType   = ck_tile::half_t;
    using BDataType   = ck_tile::half_t;
    using AccDataType = float;
    using CDataType   = ck_tile::half_t;

    ck_tile::index_t verification = 0;
    // Default problem size; M and N must be divisible by the per-block tile
    // sizes chosen below (kGridSize uses exact integer division).
    ck_tile::index_t M = 3328;
    ck_tile::index_t N = 4096;
    ck_tile::index_t K = 4096;

    if(argc == 2)
    {
        verification = std::stoi(argv[1]);
    }
    if(argc == 5)
    {
        verification = std::stoi(argv[1]);
        M            = std::stoi(argv[2]);
        N            = std::stoi(argv[3]);
        K            = std::stoi(argv[4]);
    }

    printf("*** Kernel A test ***\n");
    printf(" --> Using PADDING_K_FIRST\n");
    printf(" --> Using mfma_32x32x(8x2)\n");

    // Both A and B are stored K-contiguous; C is N-contiguous.
    const ck_tile::index_t Lda = K;
    const ck_tile::index_t Ldb = K;
    const ck_tile::index_t Ldc = N;

    const auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
    const auto a_strides = std::array<ck_tile::index_t, 2>{Lda, 1};

    const auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
    const auto b_strides = std::array<ck_tile::index_t, 2>{Ldb, 1};

    const auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
    const auto c_strides = std::array<ck_tile::index_t, 2>{Ldc, 1};

    // Host-side tensors for input generation and verification.
    ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
    ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
    ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);

    // Small integer values keep fp16 accumulation exact for verification.
    ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
    ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);

    ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());

    a_buf.ToDevice(a_host.mData.data());
    b_buf.ToDevice(b_host.mData.data());

    // Guaranteed vector-access alignment (in elements) for each operand.
    constexpr ck_tile::index_t kAAlignment = 8;
    constexpr ck_tile::index_t kBAlignment = 8;
    constexpr ck_tile::index_t kCAlignment = 8;

    constexpr ck_tile::index_t kBlockSize = 256;

    constexpr ck_tile::index_t kGemmMPerBlock = 256;
    constexpr ck_tile::index_t kGemmKPerBlock = 32;
    constexpr ck_tile::index_t kGemmNPerBlock = 128;

    // One workgroup per (kGemmMPerBlock x kGemmNPerBlock) tile of C.
    // NOTE(review): exact division — assumes M % kGemmMPerBlock == 0 and
    // N % kGemmNPerBlock == 0; trailing rows/columns are silently dropped
    // otherwise.
    ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);

    std::cout << "grid size " << kGridSize << std::endl;

    constexpr ck_tile::index_t kWarpSize     = 64; // AMD GPU warp size
    constexpr ck_tile::index_t kWarpPerCu    = 8;  // 2 warps per SIMD
    constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / kWarpSize;
    constexpr ck_tile::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;

    using gemm_kernel = ck_tile::Gemm<ADataType,
                                      BDataType,
                                      AccDataType,
                                      CDataType,
                                      CElementFunction,
                                      kAAlignment,
                                      kBAlignment,
                                      kCAlignment,
                                      kBlockSize,
                                      kGemmMPerBlock,
                                      kGemmNPerBlock,
                                      kGemmKPerBlock>;

    // stream_config: default stream, time the kernel, log level 0,
    // 5 warmup iterations, 1000 timed repeats. Returned time is in ms.
    float ave_time = ck_tile::launch_kernel(
        ck_tile::stream_config{nullptr, true, 0, 5, 1000},
        ck_tile::make_kernel<kBlockPerCu>(gemm_kernel{},
                                          kGridSize,
                                          kBlockSize,
                                          0,
                                          static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
                                          static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
                                          static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
                                          M,
                                          N,
                                          K,
                                          Lda,
                                          Ldb,
                                          Ldc,
                                          CElementFunction{}));
    auto pass = true;

    if(verification)
    {
        // Host reference GEMM.
        ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
        // Fix: the second template argument is the B element type (the
        // original passed ADataType twice — harmless only while A and B
        // happen to share a type).
        reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_host, b_host, c_host_ref);
        c_buf.FromDevice(c_host_dev.mData.data());
        pass &= ck_tile::check_err(c_host_dev, c_host_ref);
        std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
    }

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

    // ave_time is in ms, so flop / 1e9 / ms == TFLOP/s and bytes / 1e6 / ms == GB/s.
    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    return !pass;
}
|
||||
139
tutorial/ck_tile/gemm/02_padding_k_first/gemm.hpp
Normal file
139
tutorial/ck_tile/gemm/02_padding_k_first/gemm.hpp
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"

#include "block_gemm_pipeline_agmem_bgmem_creg.hpp"
#include "grid_gemm.hpp"

namespace ck_tile {

/// Grid-level problem description: the element types participating in
/// C = A * B plus the elementwise functor applied to each C value.
template <typename ADataType_,
          typename BDataType_,
          typename AccDataType_,
          typename CDataType_,
          typename CElementFunction_>
struct GridGemmProblem
{
    using ADataType   = ADataType_;
    using BDataType   = BDataType_;
    using AccDataType = AccDataType_;
    using CDataType   = CDataType_;

    using CElementFunction = CElementFunction_;
};

/// Compile-time per-block tile extents.
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
struct TileGemmShape
{
    static constexpr index_t kM = kMPerTile;
    static constexpr index_t kN = kNPerTile;
    static constexpr index_t kK = kKPerTile;
};

/// Problem description consumed by the block-level pipeline.
/// NOTE: CDataType_ is the pipeline's output element type; the grid-level
/// code below instantiates this slot with its accumulator type.
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          index_t kBlockSize_,
          typename BlockGemmShape_>
struct BlockGemmPipelineProblem
{
    using ADataType      = remove_cvref_t<ADataType_>;
    using BDataType      = remove_cvref_t<BDataType_>;
    using CDataType      = remove_cvref_t<CDataType_>;
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    static constexpr index_t kBlockSize = kBlockSize_;
};

/// Device-side kernel functor computing C = A * B.
/// A is [M, K] with leading dimension Lda, B is [N, K] with leading
/// dimension Ldb, C is [M, N] with leading dimension Ldc; the kAlignment
/// parameters are the guaranteed vector-access widths in elements.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename CElementFunction,
          index_t kAAlignment,
          index_t kBAlignment,
          index_t kCAlignment,
          index_t kBlockSize_,
          index_t kMPerBlock_,
          index_t kNPerBlock_,
          index_t kKPerBlock_>
struct Gemm
{
    static constexpr index_t kBlockSize = kBlockSize_;

    using GridGemmProblem_ =
        GridGemmProblem<ADataType, BDataType, AccDataType, CDataType, CElementFunction>;

    /// Grid-level policy: tile sizes, the block-id -> tile-coordinate map,
    /// and the factory for the block-level pipeline.
    struct GridGemmPolicy
    {
        static constexpr index_t kBlockSize = kBlockSize_;
        static constexpr index_t kMPerBlock = kMPerBlock_;
        static constexpr index_t kNPerBlock = kNPerBlock_;
        static constexpr index_t kKPerBlock = kKPerBlock_;

        /// Returns a callable mapping a linear block id to its (m-tile,
        /// n-tile) coordinate over an M0 x N0 tile grid.
        template <typename Problem>
        CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
        {
            // A merge transform's lower-index computation splits the linear
            // id back into (n-tile, m-tile) components.
            const auto splitter = make_merge_transform(make_tuple(N0, M0));

            return [splitter](index_t block_id) {
                multi_index<2> nm;
                splitter.calculate_lower_index(nm, make_multi_index(block_id));

                // Swap components into (m-tile, n-tile) order.
                return make_multi_index(nm.at(number<1>{}), nm.at(number<0>{}));
            };
        }

        /// Builds the block-level pipeline; its output type is the grid
        /// accumulator type (converted to CDataType in the epilogue).
        template <typename Problem>
        CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
        {
            using BlockGemmPipelineProblem_ =
                BlockGemmPipelineProblem<ADataType,
                                         BDataType,
                                         AccDataType,
                                         kBlockSize,
                                         TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
            return BlockGemmPipelineAGmemBGmemCReg<BlockGemmPipelineProblem_>{};
        }
    };

    using GridGemm_ = GridGemm<GridGemmProblem_, GridGemmPolicy>;

    /// Kernel entry point: wraps the raw pointers in global-memory tensor
    /// views and delegates to the grid-level GEMM.
    CK_TILE_DEVICE void operator()(const ADataType* p_a,
                                   const BDataType* p_b,
                                   CDataType* p_c,
                                   const index_t M,
                                   const index_t N,
                                   const index_t K,
                                   const index_t Lda,
                                   const index_t Ldb,
                                   const index_t Ldc,
                                   const CElementFunction& c_element_func) const
    {
        const auto a_dram = make_naive_tensor_view<address_space_enum::global>(
            p_a, make_tuple(M, K), make_tuple(Lda, 1), number<kAAlignment>{}, number<1>{});

        const auto b_dram = make_naive_tensor_view<address_space_enum::global>(
            p_b, make_tuple(N, K), make_tuple(Ldb, 1), number<kBAlignment>{}, number<1>{});

        auto c_dram = make_naive_tensor_view<address_space_enum::global>(
            p_c, make_tuple(M, N), make_tuple(Ldc, 1), number<kCAlignment>{}, number<1>{});

        GridGemm_{}(a_dram, b_dram, c_dram, c_element_func);
    }
};

} // namespace ck_tile
|
||||
72
tutorial/ck_tile/gemm/02_padding_k_first/grid_gemm.hpp
Normal file
72
tutorial/ck_tile/gemm/02_padding_k_first/grid_gemm.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct GridGemm
|
||||
{
|
||||
using ADataType = typename Problem::ADataType;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
using CDataType = typename Problem::CDataType;
|
||||
using AccDataType = typename Problem::AccDataType;
|
||||
using CElementFunction = typename Problem::CElementFunction;
|
||||
|
||||
static constexpr auto kMPerBlock = Policy::kMPerBlock;
|
||||
static constexpr auto kNPerBlock = Policy::kNPerBlock;
|
||||
static constexpr auto kKPerBlock = Policy::kKPerBlock;
|
||||
|
||||
template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
|
||||
CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid,
|
||||
const BGridTensorView& b_grid,
|
||||
CGridTensorView& c_grid,
|
||||
const CElementFunction& c_element_func) const
|
||||
{
|
||||
const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{});
|
||||
const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{});
|
||||
const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{});
|
||||
|
||||
// divide problem
|
||||
const auto id_block = get_block_id();
|
||||
|
||||
const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
|
||||
const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);
|
||||
|
||||
const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);
|
||||
|
||||
const auto id_tile = block2tile(id_block);
|
||||
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template at<0>() * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template at<1>() * kNPerBlock);
|
||||
|
||||
// A block window
|
||||
auto a_block_window = make_tile_window(
|
||||
a_grid, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {iM, 0});
|
||||
|
||||
// B block window
|
||||
auto b_block_window = make_tile_window(
|
||||
b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});
|
||||
|
||||
constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();
|
||||
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];
|
||||
|
||||
const auto acc_block_tile =
|
||||
block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char);
|
||||
|
||||
// cast to CDataType and apply CElementFunction
|
||||
const auto c_block_tile = tile_elementwise_in(
|
||||
[&](const auto& acc) { return c_element_func(type_convert<CDataType>(acc)); },
|
||||
acc_block_tile);
|
||||
|
||||
// store C
|
||||
auto c_window = make_tile_window(
|
||||
c_grid, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, iN});
|
||||
|
||||
store_tile(c_window, c_block_tile);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
17
tutorial/ck_tile/gemm/03_mfma_16x16x16/CMakeLists.txt
Normal file
17
tutorial/ck_tile/gemm/03_mfma_16x16x16/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# Tutorial example: GEMM using the 16x16x16 MFMA variant.
set(EXAMPLE_MFMA_16X16X16 "tile_tutorial_mfma_16x16x16")

message(DEBUG "adding example ${EXAMPLE_MFMA_16X16X16}")

# Built on demand only (via the `tutorials` umbrella target), not by `all`.
add_executable(${EXAMPLE_MFMA_16X16X16} EXCLUDE_FROM_ALL gemm.cpp)
target_include_directories(${EXAMPLE_MFMA_16X16X16} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_MFMA_16X16X16_COMPILE_OPTIONS)

# -Wno-undefined-func-template lets the sources compile without explicitly
# declaring function specializations; the remaining flags silence warnings
# the tutorial code triggers by design.
list(APPEND EXAMPLE_MFMA_16X16X16_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported)

target_compile_options(${EXAMPLE_MFMA_16X16X16} PRIVATE ${EXAMPLE_MFMA_16X16X16_COMPILE_OPTIONS})

add_dependencies(tutorials ${EXAMPLE_MFMA_16X16X16})
|
||||
@@ -0,0 +1,285 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "block_gemm_asmem_bsmem_creg_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegPolicy>
|
||||
struct BlockGemmASmemBSmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<
|
||||
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
|
||||
static constexpr index_t MWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
|
||||
static constexpr index_t NWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// Warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
|
||||
|
||||
// Construct C-Block-Tensor
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
// Hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// Warp GEMM
|
||||
if constexpr(KIterPerWarp == 0)
|
||||
{
|
||||
// c = a * b
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
// c += a * b
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Policy for BlockGemmASmemBSmemCReg with MFMA_16x16x16 instruction
|
||||
struct BlockGemmASmemBSmemCRegPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
// KERNEL_B uses 4x1 warp configuration
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
|
||||
// KERNEL_B uses mfma m16 n16 k16
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,166 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_policy.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
//
// Block-level GEMM pipeline ("non-prefetch" variant): each main-loop iteration
// moves one A tile (MPerBlock x KPerBlock) and one B tile (NPerBlock x
// KPerBlock) from global memory through registers into LDS, then runs the
// policy-provided block GEMM on the LDS data, accumulating C in registers.
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegPolicy>
struct BlockGemmPipelineAGmemBGmemCReg
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static constexpr index_t kMPerBlock = BlockGemmShape::kM;
    static constexpr index_t kNPerBlock = BlockGemmShape::kN;
    static constexpr index_t kKPerBlock = BlockGemmShape::kK;

    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    // Total LDS bytes required: the A staging buffer rounded up to a 16-byte
    // boundary, followed by the B staging buffer. Mirrors the pointer offsets
    // computed inside operator().
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
    {
        return integer_divide_ceil(
                   sizeof(ADataType) *
                       Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
                   16) *
                   16 +
               sizeof(BDataType) *
                   Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
    }

    // Runs the K-loop for this workgroup and returns the C block tile
    // (a register-distributed tensor).
    //
    // a_dram_block_window_tmp / b_dram_block_window_tmp: windows positioned at
    //     the first K tile of this block's A rows / B columns.
    // num_loop: number of K tiles to consume.
    // p_smem: raw LDS backing store, at least GetStaticLdsSize() bytes.
    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                        const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                        index_t num_loop,
                                        void* p_smem) const
    {
        static_assert(
            std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        // -----------------------------------------------------------------------------------------
        // Definitions of all needed tiles

        // A tile in LDS (starts at the base of the shared-memory slab)
        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);

        constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();

        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        // Round the A region up to 16 bytes so the B region starts on a
        // 16-byte boundary (kept consistent with GetStaticLdsSize()).
        constexpr index_t a_lds_block_space_size_aligned =
            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
            16;

        // B tile in LDS, placed immediately after the aligned A region
        BDataType* p_b_lds = static_cast<BDataType*>(
            static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));

        constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();

        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

        // A DRAM tile window for load
        auto a_copy_dram_window =
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeADramTileDistribution<Problem>());

        // A LDS tile window for store; reuses the DRAM load's thread
        // distribution so each thread stores exactly the elements it loaded
        auto a_copy_lds_window =
            make_tile_window(a_lds_block,
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             {0, 0},
                             a_copy_dram_window.get_tile_distribution());

        // B DRAM tile window for load
        auto b_copy_dram_window =
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeBDramTileDistribution<Problem>());

        // B LDS tile window for store (same distribution trick as for A)
        auto b_copy_lds_window =
            make_tile_window(b_lds_block,
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             {0, 0},
                             b_copy_dram_window.get_tile_distribution());

        // A LDS tile for block GEMM
        auto a_lds_gemm_window = make_tile_window(
            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // B LDS tile for block GEMM
        auto b_lds_gemm_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // Block GEMM
        auto block_gemm = BlockGemm();

        // Acc register tile (type deduced from a dry call to the block GEMM)
        auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};

        using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
        using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

        using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
        using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

        // Register staging tiles for the global -> LDS copy
        ABlockTile a_block_tile;
        BBlockTile b_block_tile;

        // Per-iteration window step: advance along K only
        using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
        using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
        constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
        constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);

        // -------------------------------------------------------------------------------------
        // Gemm pipeline start

        // Initialize C
        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

        // non-prefetch: load / store / compute strictly in sequence each loop
        index_t iCounter = num_loop;

        while(iCounter > 0)
        {
            // global memory -> registers
            a_block_tile = load_tile(a_copy_dram_window);
            b_block_tile = load_tile(b_copy_dram_window);
            // advance the DRAM windows along K for the next iteration
            move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
            move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
            // registers -> LDS
            store_tile(a_copy_lds_window, a_block_tile);
            store_tile(b_copy_lds_window, b_block_tile);

            // make every thread's LDS stores visible before computing on them
            block_sync_lds();
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            // keep the next iteration's stores from overwriting LDS that is
            // still being read
            block_sync_lds();

            iCounter--;
        }

        return c_block_tile;
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,129 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_asmem_bsmem_creg.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Policy for BlockGemmPipelineAGmemBGmemCReg with PADDING_K_FIRST optimization
struct BlockGemmPipelineAGmemBGmemCRegPolicy
{
    // 3d + PADDING_K_FIRST - adds padding to K dimension to avoid bank conflicts
    //
    // A is staged in LDS as (M, K/kKPack, kKPack) with an M-row stride of
    // (K/kKPack + 1) * kKPack elements, i.e. one extra kKPack group of padding
    // per row; the descriptor is then merged back into a logical (M, K) view.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
    {
        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
        // K packing factor: 8 elements (16 B for 2-byte types such as fp16;
        // NOTE(review): confirm if used with other element sizes)
        constexpr index_t kKPack = 8;

        // PADDING_K_FIRST: stride is (kKPerBlock / kKPack + 1) * kKPack instead of kKPerBlock
        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
            make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
            number<kKPack>{},
            number<1>{});

        // merge (K/kKPack, kKPack) back into a single logical K dimension
        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
            a_lds_block_desc_0,
            make_tuple(make_pass_through_transform(kMPerBlock),
                       make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
            make_tuple(sequence<0>{}, sequence<1, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return a_lds_block_desc;
    }

    // 3d + no padding for B (PADDING_K_FIRST only pads A in version2)
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
    {
        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t kKPack     = 8;

        // B uses same layout as NAIVE (no padding): dense (N, K/kKPack, kKPack)
        constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
            make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
            number<kKPack>{},
            number<1>{});

        constexpr auto b_lds_block_desc = transform_tensor_descriptor(
            b_lds_block_desc_0,
            make_tuple(make_pass_through_transform(kNPerBlock),
                       make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
            make_tuple(sequence<0>{}, sequence<1, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return b_lds_block_desc;
    }

    // Thread distribution for loading the (M, K) A block tile from global
    // memory: each thread reads K1-element (16-byte) vectors; the warp's lanes
    // are factored into M2 x K0 (M2 = warp_size / K0) - presumably lanes along
    // K0 read consecutive vectors for coalescing - with M1 warps and M0
    // per-thread repeats covering the rest of M.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;

        constexpr index_t kBlockSize = Problem::kBlockSize;

        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        constexpr index_t K1 = 16 / sizeof(ADataType); // elements per 16-B vector
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t M2 = get_warp_size() / K0;
        // coalesce reading for each blocks
        constexpr index_t M1 = kBlockSize / get_warp_size();
        constexpr index_t M0 = kMPerBlock / (M2 * M1);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    // Thread distribution for loading the (N, K) B block tile; identical
    // structure to the A distribution with N in place of M.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;

        constexpr index_t kBlockSize = Problem::kBlockSize;

        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        constexpr index_t K1 = 16 / sizeof(BDataType);
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;
        // coalesce reading for each blocks
        constexpr index_t N1 = kBlockSize / get_warp_size();
        constexpr index_t N0 = kNPerBlock / (N2 * N1);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    // Block GEMM that reads A/B tiles from LDS and accumulates C in registers.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        return BlockGemmASmemBSmemCReg<Problem>{};
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
158
tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.cpp
Normal file
158
tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.cpp
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm.hpp"
|
||||
#include "../reference_gemm.hpp"
|
||||
|
||||
/*
|
||||
* KERNEL_B: GEMM with PADDING_K_FIRST + MFMA_16x16x16
|
||||
* A [M, K]
|
||||
* B [N, K]
|
||||
* C [M, N]
|
||||
*/
|
||||
|
||||
// elementwise lambda
|
||||
struct CElementFunction
|
||||
{
|
||||
template <typename X>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const X& x) const
|
||||
{
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t M = 3328;
|
||||
ck_tile::index_t N = 4096;
|
||||
ck_tile::index_t K = 4096;
|
||||
|
||||
if(argc == 2)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
}
|
||||
if(argc == 5)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
M = std::stoi(argv[2]);
|
||||
N = std::stoi(argv[3]);
|
||||
K = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
printf("*** Kernel B test ***\n");
|
||||
printf(" --> Using PADDING_K_FIRST\n");
|
||||
printf(" --> Using mfma_16x16x16\n");
|
||||
|
||||
const ck_tile::index_t Lda = K;
|
||||
const ck_tile::index_t Ldb = K;
|
||||
const ck_tile::index_t Ldc = N;
|
||||
|
||||
const auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
const auto a_strides = std::array<ck_tile::index_t, 2>{Lda, 1};
|
||||
|
||||
const auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
const auto b_strides = std::array<ck_tile::index_t, 2>{Ldb, 1};
|
||||
|
||||
const auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
const auto c_strides = std::array<ck_tile::index_t, 2>{Ldc, 1};
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.mData.data());
|
||||
b_buf.ToDevice(b_host.mData.data());
|
||||
|
||||
// Alignment
|
||||
constexpr ck_tile::index_t kAAlignment = 8;
|
||||
constexpr ck_tile::index_t kBAlignment = 8;
|
||||
constexpr ck_tile::index_t kCAlignment = 8;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 256;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 32;
|
||||
constexpr ck_tile::index_t kGemmNPerBlock = 128;
|
||||
|
||||
ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
|
||||
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kWarpSize = 64; // AMD GPU warp size
|
||||
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / kWarpSize;
|
||||
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
|
||||
using gemm_kernel = ck_tile::Gemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CElementFunction,
|
||||
kAAlignment,
|
||||
kBAlignment,
|
||||
kCAlignment,
|
||||
kBlockSize,
|
||||
kGemmMPerBlock,
|
||||
kGemmNPerBlock,
|
||||
kGemmKPerBlock>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 5, 1000},
|
||||
ck_tile::make_kernel<kBlockPerCu>(gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
Lda,
|
||||
Ldb,
|
||||
Ldc,
|
||||
CElementFunction{}));
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// reference gemm
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, ADataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
c_buf.FromDevice(c_host_dev.mData.data());
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
139
tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.hpp
Normal file
139
tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.hpp
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "grid_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Compile-time problem description for GridGemm: the A/B/C/accumulator
// element types plus the element-wise functor applied to each C element
// before it is written out.
template <typename ADataType_,
          typename BDataType_,
          typename AccDataType_,
          typename CDataType_,
          typename CElementFunction_>
struct GridGemmProblem
{
    using ADataType   = ADataType_;
    using BDataType   = BDataType_;
    using AccDataType = AccDataType_;
    using CDataType   = CDataType_;

    using CElementFunction = CElementFunction_;
};
|
||||
|
||||
// Static (M, N, K) extents of one GEMM block tile.
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
struct TileGemmShape
{
    static constexpr index_t kM = kMPerTile;
    static constexpr index_t kN = kNPerTile;
    static constexpr index_t kK = kKPerTile;
};
|
||||
|
||||
// Compile-time problem description for BlockGemmPipelineAGmemBGmemCReg.
// CDataType_ is the element type of the pipeline's accumulator/output tile
// (this tutorial instantiates it with the GEMM's AccDataType).
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          index_t kBlockSize_,
          typename BlockGemmShape_>
struct BlockGemmPipelineProblem
{
    using ADataType      = remove_cvref_t<ADataType_>;
    using BDataType      = remove_cvref_t<BDataType_>;
    using CDataType      = remove_cvref_t<CDataType_>;
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    // threads per workgroup
    static constexpr index_t kBlockSize = kBlockSize_;
};
|
||||
|
||||
// C = A * B
// Device-side kernel functor: binds data types, alignments and block-tile
// sizes at compile time and forwards one launch to GridGemm.
// A is [M, K] (leading dim Lda), B is [N, K] (leading dim Ldb), C is [M, N]
// (leading dim Ldc); all row-major.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename CElementFunction,
          index_t kAAlignment,
          index_t kBAlignment,
          index_t kCAlignment,
          index_t kBlockSize_,
          index_t kMPerBlock_,
          index_t kNPerBlock_,
          index_t kKPerBlock_>
struct Gemm
{
    static constexpr index_t kBlockSize = kBlockSize_;

    using GridGemmProblem_ =
        GridGemmProblem<ADataType, BDataType, AccDataType, CDataType, CElementFunction>;

    // Policy consumed by GridGemm: tile sizes, block-id -> tile mapping, and
    // the block-level pipeline factory.
    struct GridGemmPolicy
    {
        static constexpr index_t kBlockSize = kBlockSize_;
        static constexpr index_t kMPerBlock = kMPerBlock_;
        static constexpr index_t kNPerBlock = kNPerBlock_;
        static constexpr index_t kKPerBlock = kKPerBlock_;

        // Returns a callable mapping a linear block id to its (tile_m, tile_n)
        // coordinate by inverting a merge transform over (N0, M0).
        // NOTE(review): which index varies fastest in the linear id depends on
        // make_merge_transform's dimension ordering - confirm against its docs.
        template <typename Problem>
        CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
        {
            // despite the name, this merge transform is used backwards:
            // calculate_lower_index() splits the merged block id into (n, m)
            const auto unmerge = make_merge_transform(make_tuple(N0, M0));

            return [unmerge](index_t block_id) {
                multi_index<2> unmerged;
                unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));

                // swap to (m, n) order
                return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
            };
        }

        // Instantiates the block pipeline. Note: AccDataType is passed as the
        // pipeline's CDataType, so the block tile is accumulated in
        // AccDataType and only converted to CDataType by the caller.
        template <typename Problem>
        CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
        {
            using BlockGemmPipelineProblem_ =
                BlockGemmPipelineProblem<ADataType,
                                         BDataType,
                                         AccDataType,
                                         kBlockSize,
                                         TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
            return BlockGemmPipelineAGmemBGmemCReg<BlockGemmPipelineProblem_>{};
        }
    };

    using GridGemm_ = GridGemm<GridGemmProblem_, GridGemmPolicy>;

    // Kernel entry: wraps the raw device pointers into global-memory tensor
    // views (with the given vector-access alignment hints) and runs the grid
    // GEMM.
    CK_TILE_DEVICE void operator()(const ADataType* p_a,
                                   const BDataType* p_b,
                                   CDataType* p_c,
                                   const index_t M,
                                   const index_t N,
                                   const index_t K,
                                   const index_t Lda,
                                   const index_t Ldb,
                                   const index_t Ldc,
                                   const CElementFunction& c_element_func) const
    {
        const auto a_dram = [&] {
            return make_naive_tensor_view<address_space_enum::global>(
                p_a, make_tuple(M, K), make_tuple(Lda, 1), number<kAAlignment>{}, number<1>{});
        }();

        const auto b_dram = [&] {
            return make_naive_tensor_view<address_space_enum::global>(
                p_b, make_tuple(N, K), make_tuple(Ldb, 1), number<kBAlignment>{}, number<1>{});
        }();

        const auto c_dram = [&] {
            return make_naive_tensor_view<address_space_enum::global>(
                p_c, make_tuple(M, N), make_tuple(Ldc, 1), number<kCAlignment>{}, number<1>{});
        }();

        GridGemm_{}(a_dram, b_dram, c_dram, c_element_func);
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
72
tutorial/ck_tile/gemm/03_mfma_16x16x16/grid_gemm.hpp
Normal file
72
tutorial/ck_tile/gemm/03_mfma_16x16x16/grid_gemm.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Grid-level GEMM: each workgroup computes one (kMPerBlock x kNPerBlock) tile
// of C; the K reduction runs inside the block-level pipeline.
template <typename Problem, typename Policy>
struct GridGemm
{
    using ADataType        = typename Problem::ADataType;
    using BDataType        = typename Problem::BDataType;
    using CDataType        = typename Problem::CDataType;
    using AccDataType      = typename Problem::AccDataType;
    using CElementFunction = typename Problem::CElementFunction;

    static constexpr auto kMPerBlock = Policy::kMPerBlock;
    static constexpr auto kNPerBlock = Policy::kNPerBlock;
    static constexpr auto kKPerBlock = Policy::kKPerBlock;

    // a_grid: [M, K] view, b_grid: [N, K] view, c_grid: [M, N] view;
    // c_element_func is applied to each C element after conversion to
    // CDataType.
    template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
    CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid,
                                   const BGridTensorView& b_grid,
                                   CGridTensorView& c_grid,
                                   const CElementFunction& c_element_func) const
    {
        const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{});
        const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{});
        const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{});

        // divide problem
        const auto id_block = get_block_id();

        const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
        const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);

        const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);

        const auto id_tile = block2tile(id_block);

        // readfirstlane: the tile origin is uniform across the wave, so keep
        // it in a scalar register
        const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template at<0>() * kMPerBlock);
        const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template at<1>() * kNPerBlock);

        // A block window
        auto a_block_window = make_tile_window(
            a_grid, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {iM, 0});

        // B block window
        auto b_block_window = make_tile_window(
            b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});

        constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();

        // raw LDS slab for the pipeline's A/B staging buffers
        // NOTE(review): a char array carries no alignment guarantee for the
        // element types staged in it - confirm the __shared__ base alignment
        // is sufficient
        __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];

        // K / kKPerBlock main-loop iterations (any K remainder is silently
        // dropped; callers must pass K divisible by kKPerBlock)
        const auto acc_block_tile =
            block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char);

        // cast to CDataType and apply CElementFunction
        const auto c_block_tile = tile_elementwise_in(
            [&](const auto& acc) { return c_element_func(type_convert<CDataType>(acc)); },
            acc_block_tile);

        // store C
        auto c_window = make_tile_window(
            c_grid, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, iN});

        store_tile(c_window, c_block_tile);
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
10
tutorial/ck_tile/gemm/CMakeLists.txt
Normal file
10
tutorial/ck_tile/gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# Make the shared tutorial headers in this directory (e.g. reference_gemm.hpp)
# visible to every kernel subdirectory below.
include_directories(AFTER
    ${CMAKE_CURRENT_LIST_DIR}
)

# Tutorial kernels, ordered from the naive baseline to progressively
# optimized variants.
add_subdirectory(01_naive_gemm)
add_subdirectory(02_padding_k_first)
add_subdirectory(03_mfma_16x16x16)
|
||||
37
tutorial/ck_tile/gemm/reference_gemm.hpp
Normal file
37
tutorial/ck_tile/gemm/reference_gemm.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
void reference_basic_gemm(const ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
const ck_tile::HostTensor<BDataType>& b_n_k,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n)
|
||||
{
|
||||
const int N = b_n_k.mDesc.get_lengths()[0];
|
||||
const int K = b_n_k.mDesc.get_lengths()[1];
|
||||
|
||||
auto f = [&](auto m) {
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_m_k(m, k);
|
||||
BDataType v_b = b_n_k(n, k);
|
||||
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
||||
ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_acc);
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
Reference in New Issue
Block a user