From c73c50a96e3d9f396e45c4f8935c81dd074e66ff Mon Sep 17 00:00:00 2001 From: Aviral Goel <191153937+AviralGoelAMD@users.noreply.github.com> Date: Wed, 27 May 2026 18:11:21 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7714 (commit 13ae6d6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [CK_TILE] Restructure naive GEMM tutorial and add tile distribution tutorials (#7714) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Flatten naive GEMM tutorial directory structure (remove `block_level/`, `host_level/`, `warp_level/` subdirs) to match the composable_kernel repo layout - Add `CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION` macro switch to toggle between standard and transposed WarpGemm variants - Consolidate 6 verbose markdown files (~2600 lines) into one concise README (~120 lines) - Add 3 tile distribution encoding tutorials with step-by-step "How to read Ps/Ys" annotations: - Tutorial 1: A-matrix DRAM load (256×32) — NDimP=2, coalesced K-splitting - Tutorial 2: B-matrix DRAM load (128×32) — same pattern, fewer iterations - Tutorial 3: C-matrix register layout (32×32) — MFMA m32n32k8 hardware output mapping, standard vs transposed - Tile distribution tutorials guarded to build only for gfx942 and gfx950 --- tutorial/ck_tile/CMakeLists.txt | 1 + .../01_naive_gemm/BLOCK_LEVEL_PIPELINE.md | 589 ----------------- .../gemm/01_naive_gemm/HOST_LEVEL_PIPELINE.md | 618 ------------------ .../gemm/01_naive_gemm/KERNEL_ENTRY_POINT.md | 464 ------------- tutorial/ck_tile/gemm/01_naive_gemm/README.md | 217 +++--- .../gemm/01_naive_gemm/TILE_DISTRIBUTION.md | 312 --------- .../ck_tile/gemm/01_naive_gemm/WALKTHROUGH.md | 506 -------------- .../block_gemm_asmem_bsmem_creg.hpp | 0 .../block_gemm_asmem_bsmem_creg_policy.hpp | 19 +- .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 0 ..._gemm_pipeline_agmem_bgmem_creg_policy.hpp | 2 +- .../{host_level => }/grid_gemm.hpp | 0 .../gemm/01_naive_gemm/practice_gemm.hpp | 4 +- .../ck_tile/tile_distribution/CMakeLists.txt | 21 + tutorial/ck_tile/tile_distribution/README.md | 63 ++ .../tile_distribution/tile_distribution_1.cpp | 285 ++++++++ .../tile_distribution/tile_distribution_2.cpp | 240 +++++++ .../tile_distribution/tile_distribution_3.cpp | 376 +++++++++++ 18 files changed, 1097 insertions(+), 2620 deletions(-) delete mode 100644 tutorial/ck_tile/gemm/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md delete mode 100644 tutorial/ck_tile/gemm/01_naive_gemm/HOST_LEVEL_PIPELINE.md delete mode 100644 tutorial/ck_tile/gemm/01_naive_gemm/KERNEL_ENTRY_POINT.md delete mode 100644 tutorial/ck_tile/gemm/01_naive_gemm/TILE_DISTRIBUTION.md delete mode 100644 tutorial/ck_tile/gemm/01_naive_gemm/WALKTHROUGH.md rename tutorial/ck_tile/gemm/01_naive_gemm/{warp_level => }/block_gemm_asmem_bsmem_creg.hpp (100%) rename tutorial/ck_tile/gemm/01_naive_gemm/{warp_level => }/block_gemm_asmem_bsmem_creg_policy.hpp (67%) rename tutorial/ck_tile/gemm/01_naive_gemm/{block_level => }/block_gemm_pipeline_agmem_bgmem_creg.hpp (100%) rename tutorial/ck_tile/gemm/01_naive_gemm/{block_level => }/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp (98%) rename tutorial/ck_tile/gemm/01_naive_gemm/{host_level => }/grid_gemm.hpp (100%) create mode 100644 tutorial/ck_tile/tile_distribution/CMakeLists.txt create mode 100644 tutorial/ck_tile/tile_distribution/README.md create mode 100644 tutorial/ck_tile/tile_distribution/tile_distribution_1.cpp create mode 100644 tutorial/ck_tile/tile_distribution/tile_distribution_2.cpp create mode 100644 tutorial/ck_tile/tile_distribution/tile_distribution_3.cpp diff --git a/tutorial/ck_tile/CMakeLists.txt b/tutorial/ck_tile/CMakeLists.txt index 239270d833..208f8989a9 100644 --- a/tutorial/ck_tile/CMakeLists.txt +++ b/tutorial/ck_tile/CMakeLists.txt @@ -7,3 +7,4 @@ include_directories(AFTER add_subdirectory(00_copy_kernel) add_subdirectory(gemm) +add_subdirectory(tile_distribution) diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md b/tutorial/ck_tile/gemm/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md deleted file mode 100644 index 114fccfd56..0000000000 --- a/tutorial/ck_tile/gemm/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md +++ /dev/null @@ -1,589 +0,0 @@ -# Block-Level Pipeline: PracticeGemmBlockPipelineAGmemBGmemCreg - -## Overview - -The **Block-Level Pipeline** is where the actual GEMM computation happens for one block tile. It orchestrates: -1. **Data movement** from DRAM → Registers → LDS -2. **GEMM computation** using data in LDS -3. **Iteration** over the K dimension when needed - -This pipeline is called by the host-level pipeline for each block tile that covers a portion of the output matrix C. - ---- - -## Architecture: Problem and Policy - -Like other components in CK Tile, the block pipeline follows the **Problem/Policy** pattern: - -### Problem: `PracticeGemmBlockPipelineProblem` -Contains: -- **Data types**: `ADataType`, `BDataType`, `CDataType`, `AccDataType` -- **Shape information**: `BlockTile` and `WaveTile` dimensions - -### Policy: `PracticeGemmBlockPolicy` -Contains strategies for: -1. **Tile Distribution** (`MakeADramTileDistribution`, `MakeBDramTileDistribution`) - - Defines how 256 threads in a block map to elements of a block tile - - Each thread knows which elements to load/store from DRAM to its registers - - We'll cover tile distribution construction in detail later - -2. **LDS Layout** (`MakeALdsBlockDescriptor`, `MakeBLdsBlockDescriptor`) - - Describes how data is logically organized in Local Data Share (LDS) - - Optimizes for bank conflict avoidance and efficient access patterns - - We'll cover LDS descriptor construction in detail later - -3. **Warp Pipeline** (`GetPracticeWaveGemmPipeline`) - - Returns the warp-level GEMM implementation - ---- - -## Inputs and Outputs - -```cpp -template -CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const -``` - -### Inputs: -- `a_dram_block_window_tmp`: Tile window over A in DRAM (size: MPerBlock × KPerBlock) -- `b_dram_block_window_tmp`: Tile window over B in DRAM (size: NPerBlock × KPerBlock) -- `num_loop`: Number of iterations along K dimension -- `p_smem`: Pointer to shared memory (LDS) - -### Output: -- `c_block_tile`: A `static_distributed_tensor` containing the computed C tile in registers (VGPRs) - ---- - -## Step-by-Step Walkthrough - -### Step 1: Create LDS Tensor Views - -```cpp -// A tile in LDS -ADataType* p_a_lds = static_cast(p_smem); -constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); -auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - -// B tile in LDS (placed after A in shared memory) -BDataType* p_b_lds = static_cast( - static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); -constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); -auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); -``` - -**What's happening:** -- We partition the shared memory (`p_smem`) into two regions: one for A, one for B -- We create **tensor views** over these LDS regions using descriptors from the policy -- `a_lds_block` and `b_lds_block` are logical views over raw LDS memory - -**Memory Layout:** -``` -Shared Memory (LDS): -┌─────────────────────┬─────────────────────┐ -│ A Block Tile │ B Block Tile │ -│ (256×32 fp16) │ (128×32 fp16) │ -└─────────────────────┴─────────────────────┘ -↑ ↑ -p_a_lds p_b_lds -``` - ---- - -### Step 2: Create Tile Windows for Data Movement - -We create **6 tile windows** for different purposes: - -#### 2a. DRAM → Registers (Load from DRAM) - -```cpp -auto a_copy_dram_window = make_tile_window( - a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), // 256×32 - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); // ← Tile distribution! -``` - -**Key Points:** -- `a_copy_dram_window` is a `tile_window_with_static_distribution` -- The **tile distribution** tells each thread which elements to load from DRAM -- This window will **slide along the K dimension** in the loop - -#### 2b. Registers → LDS (Store to LDS) - -```cpp -auto a_copy_lds_window = make_tile_window( - a_lds_block, - make_tuple(number{}, number{}), // 256×32 - {0, 0}, // Origin at (0, 0) in LDS - a_copy_dram_window.get_tile_distribution()); // ← Same distribution as DRAM! -``` - -**Key Points:** -- Uses the **same tile distribution** as `a_copy_dram_window` -- This ensures each thread stores to LDS in the same pattern it loaded from DRAM -- Origin is always `{0, 0}` because LDS is reused for each K iteration - -#### 2c. LDS → Registers (GEMM Input) - -```cpp -auto a_lds_gemm_window = make_tile_window( - a_lds_block, - make_tuple(number{}, number{}), - {0, 0}); // No tile distribution! -``` - -**Key Points:** -- This is a `tile_window_with_static_lengths` (no explicit distribution) -- Used as input to the warp-level GEMM -- The warp GEMM will handle its own thread mapping internally - -**Similar windows are created for B:** -- `b_copy_dram_window`: Load B from DRAM -- `b_copy_lds_window`: Store B to LDS -- `b_lds_gemm_window`: Read B from LDS for GEMM - ---- - -### Step 3: Create Distributed Tensors (VGPRs) - -```cpp -using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); -using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); - -using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); -using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); - -ABlockTile a_block_tile; // Per-thread registers for A -BBlockTile b_block_tile; // Per-thread registers for B -``` - -#### What is `make_static_distributed_tensor`? - -**`make_static_distributed_tensor`** creates a **`static_distributed_tensor`**, which is a compile-time abstraction for **distributed per-thread register storage**. - -**Key Properties:** -1. **Per-thread VGPRs**: Each thread owns a **different slice** of the tile in its registers -2. **Compile-time sized**: Buffer size determined by tile distribution at compile time -3. **Zero-overhead**: All indexing and layout transformations happen at compile time - -**How it works:** - -```cpp -template -struct static_distributed_tensor -{ - using DataType = remove_cvref_t; - using StaticTileDistribution = remove_cvref_t; - - // Calculate per-thread storage size from tile distribution - using ThreadTensorDesc = - remove_cvref_t; - - static constexpr index_t kThreadElementSpaceSize = - ThreadTensorDesc{}.get_element_space_size(); - - // Per-thread register array (VGPRs) - thread_buffer thread_buf_; -}; -``` - -**The tile distribution defines:** -- **Which elements each thread owns** in the tile -- **How many elements** each thread stores (buffer size) -- **How elements are laid out** in each thread's registers - -**Concrete Example for 256×32 tile with 256 threads:** - -``` -Thread 0: a_block_tile.thread_buf_ = [A[0,0], A[0,1], ..., A[0,31]] (32 fp16 values) -Thread 1: a_block_tile.thread_buf_ = [A[1,0], A[1,1], ..., A[1,31]] (32 fp16 values) -Thread 2: a_block_tile.thread_buf_ = [A[2,0], A[2,1], ..., A[2,31]] (32 fp16 values) -... -Thread 255: a_block_tile.thread_buf_ = [A[255,0], A[255,1], ..., A[255,31]] (32 fp16 values) -``` - -**Collectively:** -- All 256 threads together hold the **entire 256×32 tile** (8192 elements) -- Each thread's buffer lives in its **own VGPRs** -- No two threads own the same element - -**Distributed Ownership Analogy:** -Think of a tile as a **jigsaw puzzle**: -- The **tile distribution** is the cutting pattern -- Each **thread** gets one puzzle piece (its slice) -- Each **`static_distributed_tensor`** is a box holding all pieces -- Each thread's **`thread_buf_`** is its individual piece in its own registers - ---- - -### Step 4: The GEMM Loop - -```cpp -// Initialize C accumulator to zero -auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; -tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - -index_t iCounter = num_loop; // Number of K iterations - -while(iCounter > 0) -{ - // 1. Load from DRAM to registers - a_block_tile = load_tile(a_copy_dram_window); // DRAM → VGPRs - b_block_tile = load_tile(b_copy_dram_window); // DRAM → VGPRs - - // 2. Move windows for next iteration - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); // Step by (0, 32) - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // Step by (0, 32) - - // 3. Store from registers to LDS - store_tile(a_copy_lds_window, a_block_tile); // VGPRs → LDS - store_tile(b_copy_lds_window, b_block_tile); // VGPRs → LDS - - // 4. Synchronize threads (ensure all data is in LDS) - block_sync_lds(); - - // 5. Compute GEMM using data in LDS - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - // 6. Synchronize threads (before overwriting LDS in next iteration) - block_sync_lds(); - - iCounter--; -} - -return c_block_tile; // Return accumulated result in registers -``` - ---- - -## Detailed Loop Breakdown - -### Phase 1: Load (DRAM → VGPRs) - -```cpp -a_block_tile = load_tile(a_copy_dram_window); -``` - -**What happens:** -1. Each thread reads **its assigned elements** from DRAM (determined by tile distribution) -2. Data is loaded into **per-thread registers** (VGPRs) -3. Uses **vectorized loads** for efficiency (e.g., loading 8 fp16 values at once) - -**Example for Thread 0:** -``` -Thread 0 loads: - A[0,0:7] (8 fp16 values, one vector load) - A[1,0:7] (8 fp16 values, one vector load) - ... -``` - -### Phase 2: Move Windows - -```cpp -constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); -move_tile_window(a_copy_dram_window, a_dram_tile_window_step); -``` - -**What happens:** -- The tile window **slides along the K dimension** by `KPerBlock` (32 in our example) -- This prepares for the next K iteration -- The window origin moves from `(0, 0)` → `(0, 32)` → `(0, 64)` → ... - -**Visualization for Problem Size 512×256×64:** -``` -Matrix A (512×64): -┌─────────────────────────────────────┐ -│ Block 0: rows 0-255 │ -│ ┌──────────┬──────────┐ │ -│ │ K=0:31 │ K=32:63 │ │ ← Window slides right -│ │ Iter 0 │ Iter 1 │ │ -│ └──────────┴──────────┘ │ -└─────────────────────────────────────┘ -``` - -### Phase 3: Store (VGPRs → LDS) - -```cpp -store_tile(a_copy_lds_window, a_block_tile); -``` - -**What happens:** -1. Each thread writes **its elements** from registers to LDS -2. Uses the **same distribution** as the DRAM load -3. Data is now in **shared memory**, accessible to all threads in the block - -**Why this step?** -- GEMM computation needs **all threads** to access **all data** -- Registers are per-thread; LDS is shared across the block -- LDS acts as a "staging area" for collaborative computation - -### Phase 4: Synchronize - -```cpp -block_sync_lds(); -``` - -**What happens:** -- All threads in the block **wait** until everyone has finished storing to LDS -- Ensures no thread starts reading from LDS before all writes are complete -- Critical for correctness! - -### Phase 5: GEMM Computation - -```cpp -block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); -``` - -**What happens:** -1. The warp-level GEMM reads data from LDS -2. Performs matrix multiplication using MFMA instructions -3. Accumulates results into `c_block_tile` (in registers) - -**Note:** `c_block_tile` stays in registers throughout all K iterations, accumulating results. - -### Phase 6: Synchronize Again - -```cpp -block_sync_lds(); -``` - -**What happens:** -- Ensures all threads have finished reading from LDS -- Safe to overwrite LDS in the next iteration - ---- - -## Memory Flow Diagram - -``` -Iteration 0 (K=0:31): -┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐ -│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │ -│ A[0:255,│ │ (per- │ │ A_block │ -│ 0:31] │ │ thread) │ │ │ -└─────────┘ └──────────┘ └─────────┘ - │ - │ block_gemm - ↓ - ┌──────────┐ - │ c_block_ │ - │ tile │ - │ (VGPRs) │ - └──────────┘ - -Iteration 1 (K=32:63): -┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐ -│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │ -│ A[0:255,│ │ (per- │ │ A_block │ -│ 32:63] │ │ thread) │ │ (reused)│ -└─────────┘ └──────────┘ └─────────┘ - │ - │ block_gemm - ↓ - ┌──────────┐ - │ c_block_ │ - │ tile │ - │ (accum.) │ - └──────────┘ -``` - ---- - -## Example: Problem Size 512×256×64 - -### Block 0 Computation - -**Input:** -- `a_dram_block_window_tmp`: Covers A[0:255, 0:31] initially -- `b_dram_block_window_tmp`: Covers B[0:127, 0:31] initially (B is transposed) -- `num_loop`: 2 (since K=64, KPerBlock=32) - -**Iteration 0:** -1. Load A[0:255, 0:31] and B[0:127, 0:31] from DRAM to VGPRs -2. Move windows: A → [0:255, 32:63], B → [0:127, 32:63] -3. Store to LDS -4. Compute: `C[0:255, 0:127] += A[0:255, 0:31] × B[0:127, 0:31]^T` - -**Iteration 1:** -1. Load A[0:255, 32:63] and B[0:127, 32:63] from DRAM to VGPRs -2. Move windows: A → [0:255, 64:95], B → [0:127, 64:95] (out of bounds, but loop ends) -3. Store to LDS -4. Compute: `C[0:255, 0:127] += A[0:255, 32:63] × B[0:127, 32:63]^T` - -**Output:** -- `c_block_tile`: Contains C[0:255, 0:127] in distributed registers - ---- - -## Key Concepts Summary - -### 1. Tile Distribution -- **Maps threads to data elements** for load/store operations -- Each thread knows exactly which elements it's responsible for -- Enables **parallel, vectorized** memory access -- **Same distribution** used for DRAM load and LDS store - -### 2. Static Distributed Tensor -- **Per-thread register storage** (VGPRs) -- Each thread owns a **different slice** of the tile -- **Compile-time sized** for zero-overhead abstraction -- Used for: `a_block_tile`, `b_block_tile`, `c_block_tile` - -### 3. Tile Window Movement -- Windows **slide** over larger tensors -- Enables iteration over the K dimension -- `move_tile_window(window, step)` updates the origin - -### 4. LDS as Staging Area -- **Shared memory** accessible to all threads in a block -- Required because GEMM needs all threads to access all data -- **Reused** across K iterations (same LDS buffer) - -### 5. Synchronization -- `block_sync_lds()` ensures memory consistency -- **Before GEMM**: All stores to LDS are complete -- **After GEMM**: All reads from LDS are complete - ---- - -## Deep Dive: `static_distributed_tensor` Mechanics - -### How Tile Distribution Creates Per-Thread Storage - -When you call: -```cpp -using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); -ABlockTile a_block_tile; -``` - -**Step 1: Extract Thread Tensor Descriptor** - -The tile distribution contains a `ys_to_d_descriptor` that maps: -- **Y dimensions** (logical tile coordinates, e.g., M, K) -- **D dimension** (per-thread register index, linearized) - -```cpp -using ThreadTensorDesc = - decltype(StaticTileDistribution{}.get_ys_to_d_descriptor()); -``` - -**Step 2: Calculate Per-Thread Buffer Size** - -```cpp -static constexpr index_t kThreadElementSpaceSize = - ThreadTensorDesc{}.get_element_space_size(); - -static constexpr index_t get_thread_buffer_size() -{ - return kThreadElementSpaceSize / PackedSize; -} -``` - -**Example:** -- 256×32 tile distributed across 256 threads -- Each thread owns 32 elements (one row) -- `thread_buffer_size = 32` (for PackedSize=1) - -**Step 3: Allocate Thread Buffer** - -```cpp -thread_buffer thread_buf_; -``` - -This is essentially: -```cpp -fp16_t data[32]; // Per-thread register array (VGPRs) -``` - -### Usage in Load/Store Operations - -**Load from DRAM:** -```cpp -a_block_tile = load_tile(a_copy_dram_window); -``` - -What happens internally: -1. Each thread queries the tile distribution: "Which elements do I own?" -2. Thread 0 learns it owns A[0,0:31] -3. Thread 0 loads those elements from DRAM into `a_block_tile.thread_buf_[0:31]` -4. All 256 threads do this **in parallel** - -**Store to LDS:** -```cpp -store_tile(a_copy_lds_window, a_block_tile); -``` - -What happens internally: -1. Each thread reads from its `a_block_tile.thread_buf_` -2. Thread 0 writes A[0,0:31] from its registers to LDS -3. All 256 threads do this **in parallel** -4. After `block_sync_lds()`, the entire tile is in shared LDS - -### Distributed Indexing - -The `static_distributed_tensor` supports compile-time indexing: - -```cpp -// Access using distributed indices -auto value = a_block_tile(tile_distributed_index{}); -``` - -Internally: -1. Convert distributed index → Y index (logical tile coordinates) -2. Calculate buffer offset using `ThreadTensorDesc` -3. Access `thread_buf_[offset]` - -All of this happens **at compile time** with zero runtime overhead! - -### Why This Design? - -**Benefits:** -1. **Parallel Memory Access**: All threads load/store simultaneously -2. **Vectorization**: Each thread can use vector loads (e.g., 8×fp16 at once) -3. **Zero Overhead**: All indexing resolved at compile time -4. **Type Safety**: Distribution mismatch caught at compile time -5. **Register Pressure**: Compiler knows exact VGPR usage - -**Trade-offs:** -- Requires compile-time tile sizes -- Distribution must be static -- More complex type system - -### Memory Hierarchy Summary - -``` -┌─────────────────────────────────────────────────────────────┐ -│ DRAM (Global Memory) │ -│ Full matrices A, B, C │ -└─────────────────────────────────────────────────────────────┘ - │ - │ load_tile (parallel, vectorized) - ↓ -┌─────────────────────────────────────────────────────────────┐ -│ VGPRs (Per-Thread Registers) │ -│ Thread 0: a_block_tile.thread_buf_ = [A[0,0:31]] │ -│ Thread 1: a_block_tile.thread_buf_ = [A[1,0:31]] │ -│ ... │ -│ Thread 255: a_block_tile.thread_buf_ = [A[255,0:31]] │ -│ │ -│ ← static_distributed_tensor manages this distribution │ -└─────────────────────────────────────────────────────────────┘ - │ - │ store_tile (parallel, vectorized) - ↓ -┌─────────────────────────────────────────────────────────────┐ -│ LDS (Shared Memory) │ -│ Entire block tile (256×32) │ -│ Accessible to all threads in block │ -└─────────────────────────────────────────────────────────────┘ -``` - -**Key Insight:** -`static_distributed_tensor` is the abstraction that enables efficient, parallel data movement between DRAM and LDS through per-thread VGPRs, with all coordination happening at compile time. - - - diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/HOST_LEVEL_PIPELINE.md b/tutorial/ck_tile/gemm/01_naive_gemm/HOST_LEVEL_PIPELINE.md deleted file mode 100644 index 43cb01fb36..0000000000 --- a/tutorial/ck_tile/gemm/01_naive_gemm/HOST_LEVEL_PIPELINE.md +++ /dev/null @@ -1,618 +0,0 @@ -# Host-Level Pipeline: Orchestrating Block-Level GEMM - -This document explains the **host-level pipeline** (`PracticeGemmHostPipeline`), which orchestrates the distribution of work across thread blocks and manages the high-level flow of the GEMM computation. - -## Overview - -The host-level pipeline is responsible for: -1. **Calculating tile coverage**: How many tiles are needed to cover matrices A, B, and C -2. **Block-to-tile mapping**: Assigning each thread block to a specific tile -3. **Creating tile windows**: Establishing sliding windows over tensor views -4. **Delegating computation**: Calling the block-level pipeline to perform actual GEMM -5. **Storing results**: Writing computed tiles from registers (VGPRs) back to DRAM - -```cpp -template -struct PracticeGemmHostPipeline -{ - template - CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, - const BDRAMTensorView& b_dram, - CDRAMTensorView& c_dram) const - { - // 1. Calculate problem dimensions and tile coverage - // 2. Map thread block to tile coordinates - // 3. Create tile windows over A and B - // 4. Call block-level pipeline to compute - // 5. Store result to C - } -}; -``` - ---- - -## Step 1: Calculate Problem Dimensions and Tile Coverage - -```cpp -// Size of the entire problem -const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K -const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N -const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K - -// Size of the block tile -const auto MPerBlock = BlockTile::at(number<0>{}); // 256 -const auto NPerBlock = BlockTile::at(number<1>{}); // 128 -const auto KPerBlock = BlockTile::at(number<2>{}); // 32 - -// Number of block tiles needed to cover C matrix -const auto num_tile_n = integer_divide_ceil(N, NPerBlock); // ceil(256/128) = 2 -const auto num_tile_m = integer_divide_ceil(M, MPerBlock); // ceil(512/256) = 2 -``` - -### What's Happening: - -1. **Extract problem dimensions** from tensor descriptors: - - `M = 512`: Rows in A and C - - `N = 256`: Columns in B and C - - `K = 64`: Inner dimension (columns of A, rows of B) - -2. **Get block tile sizes** from the `BlockTile` configuration: - - `MPerBlock = 256`: Each block processes 256 rows - - `NPerBlock = 128`: Each block processes 128 columns - - `KPerBlock = 32`: Each block processes 32 elements in K dimension per iteration - -3. **Calculate tile coverage**: - - `num_tile_m = ceil(M / MPerBlock) = ceil(512/256) = 2` tiles in M direction - - `num_tile_n = ceil(N / NPerBlock) = ceil(256/128) = 2` tiles in N direction - - **Total tiles = 2 × 2 = 4 tiles** → We need **4 thread blocks**! - -### Visual Representation: - -``` -Matrix C (512 × 256): -┌──────────────────────┬──────────────────────┐ -│ Tile (0,0) │ Tile (0,1) │ ← num_tile_n = 2 -│ 256×128 │ 256×128 │ -│ Block 0 │ Block 1 │ -│ │ │ -├──────────────────────┼──────────────────────┤ -│ Tile (1,0) │ Tile (1,1) │ -│ 256×128 │ 256×128 │ -│ Block 2 │ Block 3 │ -│ │ │ -└──────────────────────┴──────────────────────┘ - ↑ - num_tile_m = 2 - -Total blocks needed = 2 × 2 = 4 blocks - -Each block computes one 256×128 tile of the output matrix C. -``` - -### How Blocks Cover Matrices A and B: - -``` -Matrix A (512 × 64): Matrix B (256 × 64): -┌─────────────┬──────┐ ┌─────────────┬──────┐ -│ Block 0,2 │ K │ │ Block 0,1 │ K │ -│ uses rows │ → │ │ uses rows │ → │ -│ 0-255 │ │ │ 0-127 │ │ -├─────────────┼──────┤ ├─────────────┼──────┤ -│ Block 1,3 │ K │ │ Block 2,3 │ K │ -│ uses rows │ → │ │ uses rows │ → │ -│ 256-511 │ │ │ 128-255 │ │ -└─────────────┴──────┘ └─────────────┴──────┘ - 256 rows 64 cols 128 rows 64 cols - -Each block needs to iterate over K dimension (64/32 = 2 iterations) -``` - ---- - -## Step 2: Map Thread Block to Tile Coordinates - -```cpp -// Get block id (0 to total_blocks - 1) -const auto id_block = get_block_id(); - -// Map block id to 2D tile coordinates -const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n); -const auto tile_id = block2tile(id_block); - -const auto tile_id_m = tile_id.at(number<0>{}); // M coordinate -const auto tile_id_n = tile_id.at(number<1>{}); // N coordinate -``` - -### What's Happening: - -Each thread block needs to know **which tile of the output matrix C it should compute**. The `MakeBlock2TileMap` function creates a mapping from linear block ID to 2D tile coordinates. - -### The `MakeBlock2TileMap` Function: - -```cpp -CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) -{ - // Create a merge transform: (N0, M0) → linear index - const auto unmerge = make_merge_transform(make_tuple(N0, M0)); - - return [unmerge](index_t block_id) { - multi_index<2> unmerged; - // Convert linear block_id back to 2D coordinates - unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); - - // Return (m_idx, n_idx) - note the swap! - return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); - }; -} -``` - -### In Our Example (2×2 Grid): - -```cpp -// Block 0: -id_block = 0 -tile_id = block2tile(0) = (0, 0) // Top-left tile -tile_id_m = 0, tile_id_n = 0 - -// Block 1: -id_block = 1 -tile_id = block2tile(1) = (1, 0) // Bottom-left tile -tile_id_m = 1, tile_id_n = 0 - -// Block 2: -id_block = 2 -tile_id = block2tile(2) = (0, 1) // Top-right tile -tile_id_m = 0, tile_id_n = 1 - -// Block 3: -id_block = 3 -tile_id = block2tile(3) = (1, 1) // Bottom-right tile -tile_id_m = 1, tile_id_n = 1 -``` - -**Key Point**: Each of the 4 blocks knows exactly which 256×128 tile of C it's responsible for computing! - ---- - -## Step 3: Calculate Tile Origin and Create Tile Windows - -```cpp -// Calculate the starting position of this tile in the global matrix -const auto tile_origin_m = tile_id_m * MPerBlock; // e.g., Block 1: 1 * 256 = 256 -const auto tile_origin_n = tile_id_n * NPerBlock; // e.g., Block 2: 1 * 128 = 128 - -// Create tile windows over A and B tensor views -const auto a_block_window = make_tile_window( - a_dram, // Tensor view over A - make_tuple(number{}, number{}), // Window size: 256×32 - {tile_origin_m, 0} // Origin: varies by block -); - -const auto b_block_window = make_tile_window( - b_dram, // Tensor view over B - make_tuple(number{}, number{}), // Window size: 128×32 - {tile_origin_n, 0} // Origin: varies by block -); -``` - -### Tile Origins for Each Block: - -```cpp -// Block 0 (Tile 0,0): -tile_origin_m = 0 * 256 = 0 -tile_origin_n = 0 * 128 = 0 -a_block_window origin: (0, 0) → covers A rows 0-255 -b_block_window origin: (0, 0) → covers B rows 0-127 - -// Block 1 (Tile 1,0): -tile_origin_m = 1 * 256 = 256 -tile_origin_n = 0 * 128 = 0 -a_block_window origin: (256, 0) → covers A rows 256-511 -b_block_window origin: (0, 0) → covers B rows 0-127 - -// Block 2 (Tile 0,1): -tile_origin_m = 0 * 256 = 0 -tile_origin_n = 1 * 128 = 128 -a_block_window origin: (0, 0) → covers A rows 0-255 -b_block_window origin: (128, 0) → covers B rows 128-255 - -// Block 3 (Tile 1,1): -tile_origin_m = 1 * 256 = 256 -tile_origin_n = 1 * 128 = 128 -a_block_window origin: (256, 0) → covers A rows 256-511 -b_block_window origin: (128, 0) → covers B rows 128-255 -``` - -### What are Tile Windows? - -A **tile window** is a **sliding window** over a larger tensor view. It: -- Defines a **rectangular region** within the tensor -- Has a **fixed size** (e.g., 256×32 for A) -- Has an **origin** (starting position) -- Can be **moved** to access different regions -### Visual Representation (Block 0 Example): - -``` -Matrix A (512 × 64): Matrix B (256 × 64): -┌─────────────┬─────────────┐ ┌─────────────┬─────────────┐ -│ ┏━━━━━━━━━┓ │ │ │ ┏━━━━━━━━━┓ │ │ -│ ┃ Window ┃ │ │ │ ┃ Window ┃ │ │ -│ ┃ 256×32 ┃ │ │ │ ┃ 128×32 ┃ │ │ -│ ┃ K=0-31 ┃ │ │ │ ┃ K=0-31 ┃ │ │ -│ ┗━━━━━━━━━┛ │ │ │ ┗━━━━━━━━━┛ │ │ -│ │ │ ├─────────────┼─────────────┤ -├─────────────┼─────────────┤ │ │ │ -│ │ │ │ │ │ -│ │ │ │ │ │ -│ │ │ │ │ │ -└─────────────┴─────────────┘ └─────────────┴─────────────┘ - Origin: (0, 0) Origin: (0, 0) - Covers rows 0-255 Covers rows 0-127 - Covers cols 0-31 (first K iteration) Covers cols 0-31 (first K iteration) -``` - -**Note**: The window initially covers K columns 0-31. It will move to cover K columns 32-63 in the next iteration. - -### Tile Window Properties: - -```cpp -// Tile window structure (conceptual): -struct tile_window { - TensorView& tensor_view; // Reference to underlying tensor - Tuple window_lengths; // Size of the window (256, 32) - MultiIndex window_origin; // Starting position (0, 0) - - // Can move the window: - void move(MultiIndex step); // Shift window by step - - // Access data through the window: - auto load(); // Load data from windowed region -}; -``` - - -### Tile Window Movement: Iterating Over K Dimension - -In our example, **K=64** but **KPerBlock=32**, so we need **2 iterations** over the K dimension: - -``` -Matrix A (512 × 64) - Block 0's view: -┌─────────────┬─────────────┐ -│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │ -│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ ← Window slides along K -│ ┃ 256×32 ┃ │ ║ 256×32 ║ │ -│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │ -│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │ -├─────────────┼─────────────┤ -│ │ │ -│ Block 1's │ │ -│ region │ │ -└─────────────┴─────────────┘ - -Matrix B (256 × 64) - Block 0's view: -┌─────────────┬─────────────┐ -│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │ -│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ -│ ┃ 128×32 ┃ │ ║ 128×32 ║ │ -│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │ -│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │ -├─────────────┼─────────────┤ -│ Block 2's │ │ -│ region │ │ -└─────────────┴─────────────┘ -``` - -### How Windows Move (Conceptual - handled by block pipeline): - -```cpp -// Iteration 0: -a_block_window origin: (tile_origin_m, 0) // K columns 0-31 -b_block_window origin: (tile_origin_n, 0) // K columns 0-31 -// Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] - -// Move windows to next K position: -move_tile_window(a_block_window, {0, 32}); -move_tile_window(b_block_window, {0, 32}); - -// Iteration 1: -a_block_window origin: (tile_origin_m, 32) // K columns 32-63 -b_block_window origin: (tile_origin_n, 32) // K columns 32-63 -// Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] - -// Final result: -// C_tile = C_partial_0 + C_partial_1 -``` - -**Key Insight**: The tile windows **slide along the K dimension** to cover the full inner product. Each block accumulates partial results across K iterations to compute its final tile of C. - ---- - -## Step 4: Delegate to Block-Level Pipeline - -```cpp -// Get the block-level pipeline from policy -constexpr auto block_gemm_pipeline = - Policy::template GetPracticeGemmBlockPipeline(); - -// Calculate number of K iterations needed -int num_loops_k = integer_divide_ceil(K, KPerBlock); // ceil(64/32) = 2 - -// Allocate shared memory (LDS) for block-level computation -__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; - -// Call block-level pipeline to compute C tile -const auto c_block_tile = - block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); -``` - -### What's Happening: - -1. **Retrieve block pipeline**: The policy provides the block-level GEMM implementation -2. **Calculate K iterations**: How many times to iterate over the K dimension - - In our example: `K=64, KPerBlock=32` → **2 iterations** - - Each iteration processes 32 elements of the K dimension - - Results are accumulated across iterations - -3. **Allocate shared memory**: - - `__shared__` declares memory shared by all threads in the block - - `GetStaticLDSSize()` returns the required size in bytes - - This memory is used for: - - Staging data from DRAM → LDS - - Cooperative loading by threads - - Fast access during computation - -4. **Execute block pipeline**: - - Takes A and B tile windows as input - - Performs the GEMM computation: `C_tile = A_tile × B_tile` - - Returns result in `c_block_tile` (stored in VGPRs - registers) - -### Memory Hierarchy During Computation: - -``` -┌─────────────────────────────────────────────────────────────┐ -│ DRAM (Global Memory) - Slowest, Largest │ -│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ -│ │ A matrix │ │ B matrix │ │ C matrix │ │ -│ └─────────────┘ └─────────────┘ └─────────────┘ │ -└─────────────────────────────────────────────────────────────┘ - ↓ load ↓ load ↑ store -┌─────────────────────────────────────────────────────────────┐ -│ LDS (Shared Memory) - Fast, Limited Size (~64KB) │ -│ ┌─────────────┐ ┌─────────────┐ │ -│ │ A_tile │ │ B_tile │ ← Staged here │ -│ │ (p_smem) │ │ (p_smem) │ │ -│ └─────────────┘ └─────────────┘ │ -└─────────────────────────────────────────────────────────────┘ - ↓ load ↓ load -┌─────────────────────────────────────────────────────────────┐ -│ VGPRs (Registers) - Fastest, Smallest (~256 regs/thread) │ -│ ┌─────────────────────────────────────────────────────────┐ │ -│ │ c_block_tile (accumulated result) │ │ -│ │ Computation happens here using MFMA instructions │ │ -│ └─────────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────┘ -``` - -### Block Pipeline Responsibilities: - -The block pipeline (called here) will: -1. Load A and B tiles from DRAM → LDS (cooperative loading) -2. Distribute work among warps -3. Each warp loads its portion from LDS → VGPRs -4. Perform MFMA operations: `C += A × B` -5. Accumulate results in VGPRs -6. Return final `c_block_tile` in registers - ---- - -## Step 5: Store Results to DRAM - -```cpp -// Create a tile window over C for writing results -auto c_window = make_tile_window( - c_dram, // Tensor view over C - make_tuple(number{}, number{}), // Window size: 256×128 - {tile_origin_m, tile_origin_n} // Origin: varies by block -); - -// Store computed tile from VGPRs to DRAM -store_tile(c_window, c_block_tile); -``` - -### C Window Origins for Each Block: - -```cpp -// Block 0: Writes to top-left tile -c_window origin: (0, 0) → writes to C[0:255, 0:127] - -// Block 1: Writes to bottom-left tile -c_window origin: (256, 0) → writes to C[256:511, 0:127] - -// Block 2: Writes to top-right tile -c_window origin: (0, 128) → writes to C[0:255, 128:255] - -// Block 3: Writes to bottom-right tile -c_window origin: (256, 128) → writes to C[256:511, 128:255] -``` - -### What's Happening: - -1. **Create C tile window**: - - Size: 256×128 (matches our block tile size) - - Origin: Varies by block - each block writes to its assigned region - - This window defines **where** to write the results - -2. **Store tile to DRAM**: - - `c_block_tile`: Computed results in VGPRs (registers) - - `c_window`: Destination window in DRAM - - `store_tile()`: Efficiently writes data from registers → DRAM - -### The `store_tile` Function: - -Recall from our earlier discussion, `store_tile` does: - -```cpp -template -void store_tile(TileWindow& tile_window_tmp, - const DistributedTensor& dstr_tensor) -{ - // 1. Extract tile distribution from distributed tensor - using TileDstr = typename DistributedTensor::TileDistribution; - - // 2. Upgrade simple tile window to one with distribution - auto tile_window = make_tile_window( - tile_window_tmp.get_bottom_tensor_view(), - tile_window_tmp.get_window_lengths(), - tile_window_tmp.get_window_origin(), - TileDstr{} // Add distribution info - ); - - // 3. Store using vectorized writes - tile_window.store(dstr_tensor); -} -``` - -### Memory Flow: - -``` -VGPRs (Registers) DRAM (Global Memory) -┌─────────────────────┐ ┌─────────────────────┐ -│ c_block_tile │ │ C matrix │ -│ ┌───┬───┬───┬───┐ │ │ ┌───────────────┐ │ -│ │W0 │W1 │W2 │W3 │ │ store_tile │ │ │ │ -│ ├───┼───┼───┼───┤ │ ==========> │ │ c_window │ │ -│ │...│...│...│...│ │ vectorized │ │ (256×128) │ │ -│ └───┴───┴───┴───┘ │ │ │ │ │ -│ Distributed across │ │ └───────────────┘ │ -│ threads/warps │ │ Origin: (0, 0) │ -└─────────────────────┘ └─────────────────────┘ - -Each thread writes its portion using vector stores (e.g., float4) -``` - -### Store Optimization: - -The `store_tile` function: -- Uses **vectorized stores** (write multiple elements at once) -- Ensures **coalesced memory access** (adjacent threads write adjacent memory) -- Respects **tile distribution** (each thread knows what data it owns) -- Handles **out-of-bounds** checking (for partial tiles at boundaries) - ---- - -## Complete Flow Visualization - -Let's trace the complete flow for **Block 0** (other blocks follow the same pattern): - -``` -┌─────────────────────────────────────────────────────────────────┐ -│ Step 1: Calculate Tile Coverage │ -│ ┌─────────────────────────────────────────────────────────────┐ │ -│ │ M=512, N=256, K=64 │ │ -│ │ MPerBlock=256, NPerBlock=128, KPerBlock=32 │ │ -│ │ num_tile_m = ceil(512/256) = 2 │ │ -│ │ num_tile_n = ceil(256/128) = 2 │ │ -│ │ Total blocks needed = 2 × 2 = 4 blocks │ │ -│ └─────────────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────────┘ - ↓ -┌─────────────────────────────────────────────────────────────────┐ -│ Step 2: Map Block to Tile (Block 0 example) │ -│ ┌─────────────────────────────────────────────────────────────┐ │ -│ │ Block ID: 0 │ │ -│ │ Tile coordinates: (0, 0) - top-left tile │ │ -│ │ Tile origin: (0, 0) │ │ -│ │ │ │ -│ │ (Blocks 1,2,3 get different tile coordinates) │ │ -│ └─────────────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────────┘ - ↓ -┌─────────────────────────────────────────────────────────────────┐ -│ Step 3: Create Tile Windows │ -│ ┌─────────────────────────────────────────────────────────────┐ │ -│ │ a_block_window: 256×32 starting at (0,0) over A │ │ -│ │ b_block_window: 128×32 starting at (0,0) over B │ │ -│ │ Windows initially cover K columns 0-31 │ │ -│ └─────────────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────────┘ - ↓ -┌─────────────────────────────────────────────────────────────────┐ -│ Step 4: Execute Block Pipeline (2 K iterations) │ -│ ┌─────────────────────────────────────────────────────────────┐ │ -│ │ Allocate shared memory (LDS) │ │ -│ │ Call block_gemm_pipeline(a_window, b_window, 2, p_smem) │ │ -│ │ │ │ -│ │ K Iteration 0 (K=0-31): │ │ -│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │ -│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │ -│ │ ├─ Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] │ │ -│ │ └─ Move windows: {0, 32} │ │ -│ │ │ │ -│ │ K Iteration 1 (K=32-63): │ │ -│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │ -│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │ -│ │ ├─ Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] │ │ -│ │ └─ Accumulate: C_tile = C_partial_0 + C_partial_1 │ │ -│ │ │ │ -│ │ Return c_block_tile in VGPRs (256×128 accumulated result) │ │ -│ └─────────────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────────┘ - ↓ -┌─────────────────────────────────────────────────────────────────┐ -│ Step 5: Store Results │ -│ ┌─────────────────────────────────────────────────────────────┐ │ -│ │ Create c_window: 256×128 starting at (0,0) over C │ │ -│ │ store_tile(c_window, c_block_tile) │ │ -│ │ └─ Write from VGPRs → DRAM (vectorized stores) │ │ -│ │ │ │ -│ │ Block 0 writes to C[0:255, 0:127] │ │ -│ │ (Other blocks write to their respective regions) │ │ -│ └─────────────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────────┘ - -All 4 blocks execute in parallel, each computing its assigned 256×128 tile! -``` - ---- - -## Key Concepts Summary - -### 1. **Tile Coverage** -- Determines how many thread blocks are needed -- Each block processes one tile of the output matrix C -- Calculated as `ceil(dimension / tile_size)` - -### 2. **Block-to-Tile Mapping** -- Maps linear block ID to 2D tile coordinates -- Uses column-major ordering for better memory coalescing -- Each block knows which tile it's responsible for - -### 3. **Tile Windows** -- **Sliding windows** over larger tensor views -- Define a rectangular region with fixed size and movable origin -- Provide efficient, structured access to tensor data -- Can be moved to access different regions (e.g., for K iterations) - -### 4. **Memory Hierarchy** -- **DRAM (Global)**: Largest, slowest - stores full matrices -- **LDS (Shared)**: Medium, fast - stages tiles for cooperative access -- **VGPRs (Registers)**: Smallest, fastest - performs computation - -### 5. **Data Flow** -``` -DRAM → Tile Windows → LDS → VGPRs → Computation → VGPRs → DRAM - ↑ ↓ - A, B matrices C matrix -``` - ---- - -## Next Steps - -The host-level pipeline has set up the work and delegated to the block-level pipeline. Next, we'll explore: -- **Block-level pipeline**: How tiles are loaded, distributed to warps, and computed -- **Warp-level pipeline**: How warps perform MFMA operations -- **Memory optimization**: LDS usage, bank conflicts, coalescing - -The host level provides the **orchestration**, while the block and warp levels provide the **execution**! - diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/KERNEL_ENTRY_POINT.md b/tutorial/ck_tile/gemm/01_naive_gemm/KERNEL_ENTRY_POINT.md deleted file mode 100644 index 7cd0d06fc5..0000000000 --- a/tutorial/ck_tile/gemm/01_naive_gemm/KERNEL_ENTRY_POINT.md +++ /dev/null @@ -1,464 +0,0 @@ -# PracticeGemmKernel: Understanding the Kernel Entry Point - -This document explains the `PracticeGemmKernel` structure, which serves as the **entry point** for our GEMM GPU kernel. We'll dive deep into how raw memory is transformed into structured tensor views. - -## Overview - -The `PracticeGemmKernel` is a templated struct that: -1. Takes raw device memory pointers for matrices A, B, and C -2. Wraps them into **tensor views** - logical, structured views over physical memory -3. Dispatches to the host-level pipeline for computation - -```cpp -template -struct PracticeGemmKernel -{ - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - - static constexpr index_t kBlockSize = 256; - - CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a, - const typename Problem::BDataType* p_b, - typename Problem::CDataType* p_c, - const index_t M, - const index_t N, - const index_t K, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c) const - { - // Step 1: Create tensor views over raw memory - auto a_dram = make_naive_tensor_view( - p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{}); - - auto b_dram = make_naive_tensor_view( - p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{}); - - const auto c_dram = make_naive_tensor_view( - p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{}); - - // Step 2: Dispatch to host-level pipeline - PracticeGemmHostPipeline{}(a_dram, b_dram, c_dram); - } -}; -``` - ---- - -## What are Tensor Views? - -A **tensor view** is a **logical, structured view over raw physical memory**. It doesn't own or allocate memory—it simply provides a way to interpret and access existing memory as a multi-dimensional tensor. - -### Key Components of a Tensor View: - -1. **Memory Type**: Where the data lives (global/DRAM, LDS/shared, registers) -2. **Raw Pointer**: Points to the actual data in memory -3. **Shape**: Dimensions of the tensor (e.g., M×K for matrix A) -4. **Strides**: How to navigate through memory to access elements -5. **Guaranteed Vector Length**: How many consecutive elements can be loaded in one vector instruction -6. **Guaranteed Vector Stride**: The stride of those vectorizable elements - ---- - -## The Memory Abstraction Hierarchy - -CK Tile uses a three-layer abstraction to go from raw memory to structured tensors: - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Layer 3: TENSOR VIEW │ -│ ┌─────────────────────────────────────────────────────────┐ │ -│ │ • Logical multi-dimensional structure │ │ -│ │ • Shape: (M, K) = (256, 32) │ │ -│ │ • Strides: (32, 1) for row-major layout │ │ -│ │ • Provides: operator[], coordinate-based access │ │ -│ │ • Knows: How to map (i,j) → linear offset │ │ -│ └─────────────────────────────────────────────────────────┘ │ -│ ↓ wraps │ -│ ┌─────────────────────────────────────────────────────────┐ │ -│ │ Layer 2: BUFFER VIEW │ │ -│ │ ┌─────────────────────────────────────────────────────┐ │ │ -│ │ │ • Linear view of memory │ │ │ -│ │ │ • Pointer: p_data_ → device memory │ │ │ -│ │ │ • Size: Total number of elements │ │ │ -│ │ │ • Address space: global/LDS/generic │ │ │ -│ │ │ • Provides: Vectorized loads/stores, bounds checking│ │ │ -│ │ └─────────────────────────────────────────────────────┘ │ │ -│ └─────────────────────────────────────────────────────────┘ │ -│ ↓ wraps │ -│ ┌─────────────────────────────────────────────────────────┐ │ -│ │ Layer 1: RAW PHYSICAL MEMORY │ │ -│ │ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ │ -│ │ │ 0.0 │ 1.0 │ 2.0 │ 3.0 │ 4.0 │ 5.0 │ 6.0 │ 7.0 │ ... │ │ │ -│ │ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ │ │ -│ │ ↑ │ │ -│ │ p_a (raw pointer from hipMalloc) │ │ -│ └─────────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────┘ -``` - ---- - -## Deep Dive: `make_naive_tensor_view` - -Let's break down the function call for matrix A: - -```cpp -auto a_dram = make_naive_tensor_view( - p_a, // Raw pointer to device memory - make_tuple(M, K), // Shape: (256, 32) - make_tuple(stride_a, 1), // Strides: (32, 1) - row-major - number<8>{}, // Guaranteed vector length - number<1>{} // Guaranteed vector stride -); -``` - -### Function Signature: - -```cpp -template -CK_TILE_HOST_DEVICE constexpr auto -make_naive_tensor_view(DataType* __restrict__ p, - const tuple& lengths, - const tuple& strides, - number = number<-1>{}, - number = number<-1>{}) -{ - // Step 1: Create tensor descriptor (shape + stride information) - auto desc = make_naive_tensor_descriptor(lengths, - strides, - number{}, - number{}); - - // Step 2: Create buffer view (pointer + size + address space) - auto buffer_view = - make_buffer_view(p, desc.get_element_space_size()); - - // Step 3: Combine into tensor view - return tensor_view{buffer_view, desc}; -} -``` - ---- - -## Parameter Breakdown - -### 1. **Template Parameter: `address_space_enum::global`** - -Specifies where the memory lives: -- `global`: GPU global memory (DRAM) - slowest but largest -- `lds`: Local Data Share (shared memory) - fast, limited size -- `generic`: Generic address space -- `vgpr`: Vector General Purpose Registers - fastest, smallest - -In our case, `global` means the data is in GPU DRAM. - -### 2. **`p_a` - Raw Pointer** - -The raw device memory pointer returned by `hipMalloc`. Points to the start of the matrix data. - -### 3. **`make_tuple(M, K)` - Shape/Lengths** - -Defines the logical dimensions of the tensor: -- For matrix A: `(256, 32)` means 256 rows, 32 columns -- This is the **logical view**, independent of how data is physically laid out - -### 4. **`make_tuple(stride_a, 1)` - Strides** - -Defines how to navigate through memory: -- **Stride for dimension 0 (rows)**: `stride_a = K = 32` - - To move to the next row, skip 32 elements -- **Stride for dimension 1 (columns)**: `1` - - To move to the next column, skip 1 element - -**Row-major layout example:** -``` -Memory: [a₀₀, a₀₁, a₀₂, ..., a₀₃₁, a₁₀, a₁₁, a₁₂, ..., a₁₃₁, ...] - ↑ ↑ - Row 0 starts here Row 1 starts here (offset = 32) - -To access element A[i][j]: - offset = i * stride_a + j * 1 - = i * 32 + j -``` - -### 5. **`number<8>{}` - Guaranteed Last Dimension Vector Length** - -This tells the tensor view: **"The last dimension (K) is guaranteed to have at least 8 consecutive elements that can be loaded together in a single vector instruction."** - -#### Why is this important? - -Modern GPUs can load multiple elements in one instruction (vectorized loads): -- `float4`: Load 4 floats at once -- `float8`: Load 8 floats at once (if supported) - -By specifying `number<8>{}`, we're telling the system: -- "You can safely use vector loads of up to 8 elements" -- "The memory alignment and layout support this" - -**Example:** -```cpp -// Without vectorization (slow): -for (int j = 0; j < 8; j++) { - data[j] = memory[offset + j]; // 8 separate loads -} - -// With vectorization (fast): -float8 vec = *reinterpret_cast(&memory[offset]); // 1 load! -``` - -### 6. **`number<1>{}` - Guaranteed Last Dimension Vector Stride** - -This specifies the **stride between consecutive vectorizable elements** in the last dimension. - -- `number<1>{}` means: "Consecutive elements in the last dimension are contiguous in memory (stride = 1)" -- This confirms that elements `A[i][0], A[i][1], A[i][2], ..., A[i][7]` are stored consecutively - -**Why does this matter?** - -For efficient vectorized loads, elements must be: -1. **Contiguous** (stride = 1) ✓ -2. **Aligned** properly in memory -3. **Within the same cache line** (ideally) - -If the stride were `2`, it would mean: -``` -A[i][0] is at offset 0 -A[i][1] is at offset 2 (not 1!) -A[i][2] is at offset 4 -``` -This would prevent efficient vectorization. - ---- - -## What is a Buffer View? - -A **buffer view** is the middle layer between raw memory and tensor view. It provides: - -### Core Responsibilities: - -1. **Memory Management** - - Holds the raw pointer: `T* p_data_` - - Tracks buffer size: `BufferSizeType buffer_size_` - - Knows the address space: `global`, `lds`, etc. - -2. **Vectorized Access** - ```cpp - template - CK_TILE_DEVICE VectorType get(index_t offset); - ``` - - Provides efficient vector loads/stores - - Handles alignment requirements - -3. **Bounds Checking** (optional) - ```cpp - template - CK_TILE_DEVICE auto get(index_t i, index_t linear_offset); - ``` - - Can optionally check if access is within bounds - - Returns invalid value (default 0) for out-of-bounds access - -4. **Address Space Awareness** - - Uses different load/store instructions based on address space - - Global memory: `global_load`, `global_store` - - LDS: `ds_read`, `ds_write` - -### Buffer View Structure: - -```cpp -template -struct buffer_view -{ - T* p_data_; // Raw pointer - BufferSizeType buffer_size_; // Total elements - remove_cvref_t invalid_element_value_; // Value for OOB access - - // Access operators - const T& operator[](index_t i) const; // Read - T& operator()(index_t i); // Write - - // Vectorized access - template - VectorType get(index_t offset); -}; -``` - ---- - -## Visual Example: Matrix A Memory Layout - -Let's visualize how matrix A (256×32, fp16) is organized: - -### Raw Physical Memory (Linear): -``` -GPU DRAM Address Space: -┌─────────────────────────────────────────────────────────────────┐ -│ Byte 0 │ -│ ↓ │ -│ [a₀₀][a₀₁][a₀₂]...[a₀₃₁][a₁₀][a₁₁][a₁₂]...[a₁₃₁][a₂₀]... │ -│ ↑ ↑ │ -│ Row 0 (32 elements) Row 1 (32 elements) │ -│ │ -│ Total: 256 rows × 32 cols × 2 bytes/element = 16,384 bytes │ -└─────────────────────────────────────────────────────────────────┘ - ↑ - p_a (raw pointer) -``` - -### Buffer View Layer: -``` -buffer_view -┌─────────────────────────────────────────────────────────────────┐ -│ p_data_ = p_a │ -│ buffer_size_ = 256 × 32 = 8,192 elements │ -│ address_space = global (DRAM) │ -│ │ -│ Provides: │ -│ • Linear indexing: buffer_view[i] → element at offset i │ -│ • Vectorized loads: get(offset) → load 4 fp16s at once│ -│ • Bounds checking: is offset < buffer_size_? │ -└─────────────────────────────────────────────────────────────────┘ -``` - -### Tensor View Layer: -``` -tensor_view -┌─────────────────────────────────────────────────────────────────┐ -│ Shape: (256, 32) │ -│ Strides: (32, 1) │ -│ Guaranteed vector length: 8 │ -│ Guaranteed vector stride: 1 │ -│ │ -│ Logical 2D View: │ -│ Col: 0 1 2 ... 31 │ -│ Row 0: [a₀₀][a₀₁][a₀₂] ... [a₀₃₁] ← Can vector load 8 at once│ -│ Row 1: [a₁₀][a₁₁][a₁₂] ... [a₁₃₁] │ -│ Row 2: [a₂₀][a₂₁][a₂₂] ... [a₂₃₁] │ -│ ... │ -│ Row 255: [a₂₅₅,₀] ... [a₂₅₅,₃₁] │ -│ │ -│ Provides: │ -│ • Multi-dimensional indexing: A[i][j] │ -│ • Coordinate transformation: (i,j) → linear offset = i*32 + j │ -│ • Tile window creation: Extract sub-tensors │ -└─────────────────────────────────────────────────────────────────┘ -``` - ---- - -## Complete Flow: Raw Memory → Tensor View - -Let's trace the complete transformation for matrix A: - -### Step 1: Kernel Launch (Host Side) -```cpp -// On host: Allocate device memory -hipMalloc(&p_a, M * K * sizeof(fp16_t)); // Returns raw pointer - -// Launch kernel -kernel<<>>(p_a, p_b, p_c, M, N, K, ...); -``` - -### Step 2: Inside Kernel (Device Side) -```cpp -// Receive raw pointer -const fp16_t* p_a; // Points to GPU DRAM - -// Step 2a: Create tensor descriptor -auto desc = make_naive_tensor_descriptor( - make_tuple(256, 32), // Shape - make_tuple(32, 1), // Strides - number<8>{}, // Vector length - number<1>{} // Vector stride -); -// desc now knows: "This is a 256×32 tensor, row-major, vectorizable by 8" - -// Step 2b: Create buffer view -auto buffer_view = make_buffer_view( - p_a, // Raw pointer - 256 * 32 // Total elements -); -// buffer_view now wraps p_a with size and address space info - -// Step 2c: Create tensor view -auto a_dram = tensor_view{buffer_view, desc}; -// a_dram now provides structured, multi-dimensional access to p_a -``` - -### Step 3: Using the Tensor View -```cpp -// Access element A[i][j] -auto value = a_dram[make_tuple(i, j)]; - -// Create a tile window (sub-tensor) -auto tile = make_tile_window( - a_dram, - make_tuple(16, 16), // 16×16 tile - make_tuple(0, 0) // Starting at origin -); - -// Load tile into registers with vectorization -auto tile_data = load_tile(tile); // Uses vector loads internally! -``` - ---- - -## Why This Abstraction? - -### Benefits: - -1. **Type Safety**: Can't accidentally access wrong dimensions -2. **Performance**: Compiler knows about vectorization opportunities -3. **Flexibility**: Same code works for different memory spaces (DRAM, LDS, registers) -4. **Maintainability**: Logical structure separate from physical layout -5. **Optimization**: Guaranteed vector properties enable aggressive optimizations - -### Example: Without Tensor Views (Manual Indexing) -```cpp -// Ugly, error-prone, hard to optimize: -for (int i = 0; i < 16; i++) { - for (int j = 0; j < 16; j++) { - float val = p_a[tile_offset_i * stride_a + tile_offset_j + i * stride_a + j]; - // Hope the compiler vectorizes this? 🤞 - } -} -``` - -### Example: With Tensor Views (Clean, Optimized) -```cpp -// Clean, safe, automatically vectorized: -auto tile = make_tile_window(a_dram, make_tuple(16, 16), origin); -auto tile_data = load_tile(tile); // Vectorized loads guaranteed! -``` - ---- - -## Summary - -The `PracticeGemmKernel` entry point transforms raw GPU memory into structured, multi-dimensional tensors through a three-layer abstraction: - -1. **Raw Memory**: Linear array of bytes in GPU DRAM -2. **Buffer View**: Adds size, address space, and vectorized access -3. **Tensor View**: Adds shape, strides, and multi-dimensional indexing - -This abstraction enables: -- ✅ Clean, readable code -- ✅ Type-safe multi-dimensional access -- ✅ Automatic vectorization -- ✅ Flexible memory space handling -- ✅ Efficient tile-based computation - -The tensor views created here are then passed to the host-level pipeline, which orchestrates the block-level GEMM computation! - diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/README.md b/tutorial/ck_tile/gemm/01_naive_gemm/README.md index 13a117ae80..8a1a57eb40 100644 --- a/tutorial/ck_tile/gemm/01_naive_gemm/README.md +++ b/tutorial/ck_tile/gemm/01_naive_gemm/README.md @@ -1,150 +1,115 @@ -# CK Tile Practice GEMM Example +# CK Tile Naive GEMM Tutorial -This is a practice implementation of a GEMM (General Matrix Multiplication) kernel using the CK Tile API. It demonstrates the fundamental concepts of GPU kernel development using CK Tile's hierarchical tile system. +A minimal GEMM (`C = A × B`) using the CK Tile API. No optimizations — just the +core data flow through the three-level hierarchy: host → block → warp. -## CK Tile API Structure +## Key Terms -In the composable_kernel library's ck_tile API, **A Kernel is composed of a Problem, a Policy and an Epilogue**: +| Term | What it is | +|------|-----------| +| **Problem** | Shape, data types, and layout of the GEMM matrices | +| **Policy** | How data and computation are mapped to threads (tile distributions, warp configs) | +| **Pipeline** | The loop that moves data through DRAM → VGPRs → LDS → MFMA and accumulates C | +| **Epilogue** | Post-GEMM work (e.g. activation, scaling). Not used in this tutorial | -1. **Problem** describes the shape, data type, data layout, precision of our GEMM matrices -2. **Policy** describes how the data in the matrix (or tile) is mapped to the threads -3. **Epilogue** describes additional computation work performed after the gemm computations (this example does not have an epilogue) +## Execution Hierarchy -## Overview - -This example implements a complete GEMM kernel `C = A × B` using the CK Tile framework, showcasing: - -- **Problem Setup** - Setting up the problem (input/output shapes, data types, mathematical operations), composing a kernel (pipeline, policy, epilogue), kernel launch -- **Block-level Pipelining** - creating tensor views, dispatching to block-level GEMM -- **Block-level GEMM Computation** - Block tiles, tile window creation, loading/storing to DRAM and Register memory -- **Warp-level GEMM Computation** - Warp tiles, MFMA level computation - -## Problem Setup and Data Flow - -### Problem Size Configuration -We set the problem size using the M, N and K variables: -```cpp -ck_tile::index_t M = 1024; // Number of rows in A and C -ck_tile::index_t N = 512; // Number of columns in B and C -ck_tile::index_t K = 256; // Number of columns in A, rows in B +``` +practice_gemm.cpp ← host: parse args, allocate, launch, verify + └─ grid_gemm.hpp ← host-level: block-to-tile mapping, create tile windows + └─ block_gemm_pipeline_agmem_bgmem_creg.hpp + │ ← block-level: loop over K, DRAM→VGPR→LDS, call warp GEMM + └─ block_gemm_asmem_bsmem_creg.hpp + ← warp-level: LDS→VGPR, MFMA m32n32k8, accumulate C ``` -### Host Matrix Creation -Three host matrices A (M×K), B (N×K) and C (M×N) are created, initialized on the CPU and copied over to the GPU global/DRAM memory: -```cpp -// Host tensors with proper strides -ck_tile::HostTensor a_host(a_lengths, a_strides); // M × K -ck_tile::HostTensor b_host(b_lengths, b_strides); // N × K -ck_tile::HostTensor c_host(c_lengths, c_strides); // M × N - -// Initialize with random data -ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); -ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_host); - -// Allocate device memory and transfer data -ck_tile::DeviceMem a_device(a_host); -a_device.ToDevice(a_host.data()); +**Data flow per K-iteration:** +``` +A,B in DRAM ──load_tile──► VGPRs ──store_tile──► LDS ──sync──► warp GEMM (MFMA) ──► C in VGPRs ``` -### PracticeGemmShape Configuration -A PracticeGemmShape struct holds the dimension of each BlockTile and WaveTile: +After all K-iterations, C is stored back to DRAM. -```cpp -using BlockTile = ck_tile::sequence<256, 128, 32>; // M, N, K per block -using WaveTile = ck_tile::sequence<16, 16, 16>; // M, N, K per wave -``` -- A BlockTile of size MxK (256x32) on A matrix and NxK (128x32) on B matrix. A WaveTile of size MxN (16x16) on C matrix. +## File Guide +| File | Role | +|------|------| +| `practice_gemm.cpp` | Entry point: sizes, host tensors, kernel launch, verification | +| `practice_gemm.hpp` | Composes `GridGemmProblem`, `BlockGemmPipelineProblem`, and `Gemm` struct | +| `reference_gemm.hpp` | CPU reference for correctness checking | +| `grid_gemm.hpp` | Host-level pipeline: maps `blockIdx` to tile coordinates, creates A/B/C tile windows | +| `block_gemm_pipeline_agmem_bgmem_creg.hpp` | Block-level pipeline: K-loop, DRAM→LDS data movement | +| `block_gemm_pipeline_agmem_bgmem_creg_policy.hpp` | Policy: A/B DRAM tile distributions, LDS descriptors | +| `block_gemm_asmem_bsmem_creg.hpp` | Warp-level: reads A/B from LDS, runs MFMA, accumulates C | +| `block_gemm_asmem_bsmem_creg_policy.hpp` | Policy: WarpGemm type selection (standard vs transposed C) | -- BlockTiles iterate in K dimension to fetch data required for computing region of C covered by C's block tile. -- BlockTiles are further subdivided into WarpTiles. -- WarpTiles over A and B similarly work together to calculate the WarpTile of C. +## Tile Sizes -### Problem and Policy Composition -```cpp -// A Problem is composed from Shape and info about the data -using PracticeGemmHostProblem = ck_tile:: - PracticeGemmHostProblem; +From `practice_gemm.cpp` (fp16, `BlockSize=256`): -// A Policy is created describing data-to-thread mapping -using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy; +| Matrix | Block tile | Description | +|--------|-----------|-------------| +| A | 256 × 32 | M × K per block | +| B | 128 × 32 | N × K per block | +| C | 256 × 128 | M × N per block (accumulated in registers) | -// A Kernel is then composed of Problem and Policy -using gemm_kernel = ck_tile::PracticeGemmKernel; +Each block tile is further split across 4 warps (MWarp=4, NWarp=1). +The warp-level MFMA instruction is `m32n32k8`. + +## Tile Distributions + +The policy files define how threads map to tile elements: + +**A and B DRAM loads** (`block_gemm_pipeline_agmem_bgmem_creg_policy.hpp`): +- Factor M (or N) into `M0 × M1 × M2`, K into `K0 × K1` +- `P0 = warp_id → M1`, `P1 = lane_id → M2 × K0` (merged for coalescing) +- `Y0 = M0` (iterations), `Y1 = K1` (vector load width = 8 for fp16) +- See the `tile_distribution/` tutorials for worked examples with these exact shapes + +**C register layout** (`block_gemm_asmem_bsmem_creg_policy.hpp`): +- Determined by the WarpGemm type (MFMA instruction output mapping) +- Standard: M-dimension in Hs[0], N-dimension in Hs[1] +- Transposed: swaps M/N dimensions, changes which lanes hold which C elements + +## Transposed C Distribution Switch + +The macro `CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION` (default: 1) selects between +two WarpGemm variants: + +| Value | WarpGemm | C layout | +|-------|----------|----------| +| 1 (default) | `WarpGemmMfma*TransposedCDistribution` | Swapped A/B in MFMA, transposed C register layout | +| 0 | `WarpGemmMfma*` | Standard MFMA, standard C register layout | + +To build with the standard (non-transposed) variant, pass the define via compiler flags: +```bash +cmake -DCMAKE_CXX_FLAGS="-DCK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0" .. ``` -### Kernel Launch -`ck_tile::launch_kernel()` is used to launch the kernel on device. It calls the `operator()` function of `PracticeGemmKernel{}`: -```cpp -float ave_time = ck_tile::launch_kernel( - ck_tile::stream_config{nullptr, true, 0, 0, 1}, - ck_tile::make_kernel( - gemm_kernel{}, // Kernel composed of Problem + Policy - kGridSize, // Grid dimensions - kBlockSize, // Block dimensions - 0, // Dynamic shared memory - // Kernel arguments: device buffers and problem dimensions - a_device.GetDeviceBuffer(), b_device.GetDeviceBuffer(), c_device.GetDeviceBuffer(), - M, N, K, stride_a, stride_b, stride_c)); -``` - -### Result Verification -The results from the kernel are compared with results from CPU based computation function: -```cpp -// CPU reference implementation -ck_tile::HostTensor c_host_ref(c_lengths, c_strides); -reference_basic_gemm(a_host, b_host, c_host_ref); - -// Device results -ck_tile::HostTensor c_host_dev(c_lengths, c_strides); - -// Verify correctness -bool pass = ck_tile::check_err(c_host_dev, c_host_ref); -``` - -### Runtime Flow - -The main program (`practice_gemm.cpp`) is the entry point for the runtime flow: - -```cpp -int main() -{ - // 1. Define data types and problem sizes - using ADataType = ck_tile::half_t; - ck_tile::index_t M = 2048, N = 1024, K = 512; - - // 2. Create host tensors and initialize - ck_tile::HostTensor a_host(a_lengths, a_strides); - ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); - - // 3. Allocate device memory and transfer data - ck_tile::DeviceMem a_device(a_host); - - // 4. Configure tile shapes - using BlockTile = ck_tile::sequence<256, 128, 32>; - using WaveTile = ck_tile::sequence<16, 16, 16>; - - // 5. Launch kernel - using gemm_kernel = ck_tile::PracticeGemmKernel; - float ave_time = ck_tile::launch_kernel(/*...*/); - - // 6. Verify results - bool pass = verify_results(a_host, b_host, c_host); - - // 7. Print performance metrics - print_performance_metrics(ave_time, M, N, K); -} -``` +Both variants produce correct results — they differ only in how C elements are +distributed across thread registers, which affects downstream store coalescing. ## Building and Running ```bash -# From composable_kernel root directory -mkdir build && cd build -../script/cmake-ck-dev.sh ../ -make tile_tutorial_naive_gemm -j +cd /projects/composablekernel/build -# Run with sample sizes +# Configure (first time) +../script/cmake-ck-dev.sh ../ + +# Build +make tile_tutorial_naive_gemm -j +# or: ninja tile_tutorial_naive_gemm + +# Run (default: M=3328, N=4096, K=4096) ./bin/tile_tutorial_naive_gemm + +# Custom sizes (positional args: verification M N K) +./bin/tile_tutorial_naive_gemm 0 1024 512 256 ``` -This example serves as a foundation for understanding more complex GEMM implementations and optimization strategies in the CK Tile framework. + +## Reference + +- Tile distribution encoding: `include/ck_tile/core/tensor/tile_distribution_encoding.hpp` +- MFMA warp gemm: `include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp` +- Production GEMM pipeline: `include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp` diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/TILE_DISTRIBUTION.md b/tutorial/ck_tile/gemm/01_naive_gemm/TILE_DISTRIBUTION.md deleted file mode 100644 index 275d1a1c12..0000000000 --- a/tutorial/ck_tile/gemm/01_naive_gemm/TILE_DISTRIBUTION.md +++ /dev/null @@ -1,312 +0,0 @@ -# Tile Distribution: Mapping Threads to Data - -## Overview - -**Tile Distribution** describes how each thread in a thread block maps to elements of a block tile. It defines the hierarchical pattern of data distribution across threads, warps, and thread blocks. - -## The Problem - -Given a block tile of size `MPerBlock × KPerBlock` (e.g., 256×32), we need to determine: -- Which threads load which elements. -- How the threads are organized into warps. -- The number of times each warp repeats its pattern. -- The number of elements each thread can load in a single vector instruction. - ---- - -## Bottom-Up Construction Approach - -### Step 1: Determine K Dimension Layout - -**Start with the innermost dimension (K) for memory coalescing:** - -```cpp -constexpr index_t K1 = 16 / sizeof(ADataType); // Elements per thread (vector load) -constexpr index_t K0 = kKPerBlock / K1; // Threads needed in K dimension -``` - -**Example (with fp16):** -- `K1 = 16 / 2 = 8` → Each thread loads 8 fp16 elements in a single vector instruction -- `kKPerBlock = 32` -- `K0 = 32 / 8 = 4` → We need 4 threads along K to cover the entire K dimension - -**Visual:** -``` -K dimension (32 elements): -Thread 0: [0-7] Thread 1: [8-15] Thread 2: [16-23] Thread 3: [24-31] - K1=8 K1=8 K1=8 K1=8 -├──────────────────────────────────────────────────────────────┤ - K0=4 threads -``` - ---- - -### Step 2: Determine M Dimension Layout - -**Now partition the M dimension hierarchically:** - -#### Level 1: Threads per Warp in M (M2) - -```cpp -constexpr index_t M2 = get_warp_size() / K0; -``` - -- Warp size = 64 threads -- K dimension already uses `K0 = 4` threads per row -- `M2 = 64 / 4 = 16` → Each warp can have 16 threads in M dimension - -**Visual (Single Warp):** -``` - K dimension (4 threads) - ┌─────┬─────┬─────┬─────┐ - 0 │ T0 │ T1 │ T2 │ T3 │ - 1 │ T4 │ T5 │ T6 │ T7 │ - 2 │ T8 │ T9 │ T10 │ T11 │ -M 3 │ T12 │ T13 │ T14 │ T15 │ ← 16 rows - ...│ ... │ ... │ ... │ ... │ (M2=16) - 15 │ T60 │ T61 │ T62 │ T63 │ - └─────┴─────┴─────┴─────┘ - One Warp = 64 threads -``` - -#### Level 2: Warps per Block (M1) - -```cpp -constexpr index_t M1 = kBlockSize / get_warp_size(); -``` - -- `kBlockSize = 256` threads per block -- `M1 = 256 / 64 = 4` → We have 4 warps per block - -**Visual (4 Warps):** -``` - Warp 0 (rows 0-15) - Warp 1 (rows 16-31) - Warp 2 (rows 32-47) - Warp 3 (rows 48-63) - ↑ - M1 = 4 warps cover 64 rows total -``` - -#### Level 3: Repetitions (M0) - -```cpp -constexpr index_t M0 = kMPerBlock / (M2 * M1); -``` - -- `kMPerBlock = 256` rows to cover -- `M2 * M1 = 16 * 4 = 64` rows covered by all warps -- `M0 = 256 / 64 = 4` → Each warp must repeat its pattern 4 times - -**Visual (Complete Block):** -``` -┌──────────────┐ -│ Iteration 0 │ ← Warp 0: rows 0-15, Warp 1: rows 16-31, ... -│ (rows 0-63) │ -├──────────────┤ -│ Iteration 1 │ ← Warp 0: rows 64-79, Warp 1: rows 80-95, ... -│ (rows 64-127)│ -├──────────────┤ -│ Iteration 2 │ ← Warp 0: rows 128-143, Warp 1: rows 144-159, ... -│(rows 128-191)│ -├──────────────┤ -│ Iteration 3 │ ← Warp 0: rows 192-207, Warp 1: rows 208-223, ... -│(rows 192-255)│ -└──────────────┘ - M0 = 4 iterations -``` - ---- - -## The Tile Distribution Encoding - -Now we can construct the distribution: - -```cpp -tile_distribution_encoding< - sequence<1>, // [1] Replication - tuple, sequence>, // [2] Hierarchy - tuple, sequence<1, 2>>, // [3] Parallelism: - tuple, sequence<2, 0>>, // [3] Parallelism - sequence<1, 2>, // [4] Yield - sequence<0, 1> // [4] Yield -> -``` - -### [1] Replication: `sequence<1>` - -Defines how many times warp patterns are replicated: -- `1` = Each warp has a unique pattern (no replication) -- `2` = Warp 0 and Warp 1 do the same thing, Warp 2 and Warp 3 do the same thing -- `4` = All warps do the same thing - -In our case: `1` means no replication (each warp is independent). - ---- - -### [2] Hierarchy: The Multi-Level Structure - -```cpp -tuple, sequence> - └───────┬──────────┘ └──────┬────────┘ - M dimension K dimension -``` - -**Concrete values:** -- M hierarchy: `sequence<4, 4, 16>` = (4 repetitions, 4 warps, 16 threads/warp) -- K hierarchy: `sequence<4, 8>` = (4 threads, 8 elements/thread) - ---- - -### [3] Parallelism: Addressing the Hierarchy - -**The key insight:** Read the tuples **vertically** to understand indexing! - -```cpp -tuple, sequence<1, 2>> -tuple, sequence<2, 0>> -``` - -#### Reading Pattern - -**Column 1 (Dimension 0 = M):** -``` -sequence<1> → Address hierarchy index 1,1 → M1 (warps/block in M dimension) -sequence<1> -``` - -**Column 2 (Dimension 1 = K):** -``` -sequence<1, 2> -sequence<2, 0> -``` -[1,2] M2=threads/warp in M dimension -[2,0] K0=threads/warp in K dimension - ---- - -### [4] Yield Sequences: Output Ordering - -```cpp -sequence<1, 2> -sequence<0, 1> - -[1,0] means M0=repetitions/warp in M dimension -[2,1] means K1=elements/thread in K dimension -``` ---- - -## Complete Example: Thread 25 in Warp 0 - -Let's trace where **Thread 25** in **Warp 0** reads data: - -### Thread Coordinates -- Thread ID in warp: 25 -- Warp ID in block: 0 - -### Decompose Thread 25 -``` -Thread 25 in a 2D layout (M2=16, K0=4): -Row index: 25 / 4 = 6 -Col index: 25 % 4 = 1 -``` - -### M Position (Row) -``` -M0 iteration: 0 (first iteration) -M1 warp: 0 (warp 0) -M2 thread: 6 (6th row in warp) -→ M position = 0*64 + 0*16 + 6 = 6 -``` - -### K Position (Column) -``` -K0 thread: 1 (column group 1) -K1 elements: 8 (will load 8 consecutive elements) -→ K position = 1*8 + [0-7] = elements 8-15 -``` - -**Result:** Thread 25 in Warp 0 loads **row 6, columns 8-15** (8 elements). - ---- - -## Why This Matters - -### 1. **Memory Coalescing** -- Consecutive threads access consecutive memory → efficient global memory access -- K dimension uses K1=8 for vectorized loads - -### 2. **Warp Efficiency** -- All 64 threads in a warp are utilized -- Natural 2D layout: 16 threads (M) × 4 threads (K) = 64 threads - -### 3. **Scalability** -- M0 repetitions allow handling larger tiles -- Same pattern scales to different sizes - -### 4. **Register Allocation** -- Each thread knows exactly how many elements it will hold -- Compiler can allocate registers optimally - ---- - -## Summary Table - -| Parameter | Value | Meaning | -|-----------|-------|---------| -| **K1** | 8 | Elements per thread (vector width) | -| **K0** | 4 | Threads along K per row | -| **M2** | 16 | Threads along M per warp | -| **M1** | 4 | Warps per block | -| **M0** | 4 | Repetitions of warp pattern | -| **Total Threads** | 256 | M0×M1×M2 = 4×4×16 (actually M1×64) | -| **Total Elements** | 8192 | 256×32 (MPerBlock × KPerBlock) | -| **Elements/Thread** | 32 | M0×K1 = 4×8 | - ---- - -## Visualization: Complete Thread Block - -``` -Block Tile: 256×32 - - K dimension (32 elements) - ├─────────────────────┤ - 0 ┌──────────────────────┐ ┐ - 16 │ Warp 0 │ │ - 32 │ Warp 1 │ │ Iteration 0 - 48 │ Warp 2 │ │ (M0=0) - 64 │ Warp 3 │ ┘ - 80 ├──────────────────────┤ ┐ - 96 │ Warp 0 │ │ - 112 │ Warp 1 │ │ Iteration 1 - 128 │ Warp 2 │ │ (M0=1) - 144 │ Warp 3 │ ┘ - 160 ├──────────────────────┤ ┐ - 176 │ Warp 0 │ │ - 192 │ Warp 1 │ │ Iteration 2 - 208 │ Warp 2 │ │ (M0=2) - 224 │ Warp 3 │ ┘ - 240 ├──────────────────────┤ ┐ - 256 │ Warp 0 │ │ - │ Warp 1 │ │ Iteration 3 - │ Warp 2 │ │ (M0=3) - │ Warp 3 │ ┘ - └──────────────────────┘ - -Each warp processes 16 rows × 32 cols = 512 elements -Each iteration processes 64 rows × 32 cols = 2048 elements -Total: 4 iterations × 2048 = 8192 elements ✓ -``` - ---- - -## Key Takeaways - -1. **Bottom-up construction**: Start from vector width (K1), build up through thread/warp/block hierarchy -2. **Vertical reading**: The repeat and elements tuples are read column-wise to address hierarchy levels -3. **Replication controls redundancy**: How many warps share the same pattern -4. **Hierarchy encodes structure**: The multi-level sequence defines the complete mapping - -This design enables CK to achieve maximum GPU performance through optimal thread-to-data mapping! - diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/WALKTHROUGH.md b/tutorial/ck_tile/gemm/01_naive_gemm/WALKTHROUGH.md deleted file mode 100644 index d0b8400b9c..0000000000 --- a/tutorial/ck_tile/gemm/01_naive_gemm/WALKTHROUGH.md +++ /dev/null @@ -1,506 +0,0 @@ -# Practice GEMM: Step-by-Step Code Walkthrough - -This document provides a detailed walkthrough of `practice_gemm.cpp`, explaining each step of implementing a GEMM (General Matrix Multiplication) kernel using the CK Tile API. - -## Overview - -We'll implement `C = A × B` where: -- `A` is an `M × K` matrix -- `B` is an `N × K` matrix (note: transposed layout) -- `C` is an `M × N` matrix - -The implementation uses a hierarchical tiling strategy with two levels: -1. **Block Tiles**: Processed by thread blocks -2. **Wave Tiles**: Processed by warps (wavefronts) within blocks - ---- - -## Step 1: Define Data Types - -```cpp -using ADataType = ck_tile::half_t; -using BDataType = ck_tile::half_t; -using CDataType = float; -using AccDataType = float; -``` - -**What's happening:** -- We use `half_t` (FP16) for input matrices A and B. -- We use `float` (FP32) for output matrix C and accumulation for numerical accuracy -- In typical CK examples, this information is part of a `GemmConfig` struct, but here we define it directly for simplicity ---- - -## Step 2: Define Problem Size - -```cpp -ck_tile::index_t M = 512; -ck_tile::index_t N = 256; -ck_tile::index_t K = 64; -ck_tile::index_t verification = 1; - -ck_tile::index_t stride_a = K; -ck_tile::index_t stride_b = K; -ck_tile::index_t stride_c = N; -``` - -**What's happening:** -- `M = 512`: Number of rows in A and C -- `N = 256`: Number of columns in B and C -- `K = 64`: Inner dimension (columns of A, rows of B) -- Strides define memory layout (row-major for A and C, transposed for B) - -**Memory Layout:** -``` -Matrix A (M×K): Matrix B (N×K): Matrix C (M×N): -[512 rows] [256 rows] [512 rows] -[64 cols] [64 cols] [256 cols] -stride = K stride = K stride = N -``` - ---- - -## Step 3: Create Host Tensors - -```cpp -auto a_lengths = std::array{M, K}; -auto b_lengths = std::array{N, K}; -auto c_lengths = std::array{M, N}; - -auto a_strides = std::array{stride_a, 1}; -auto b_strides = std::array{stride_b, 1}; -auto c_strides = std::array{stride_c, 1}; - -ck_tile::HostTensor a_host(a_lengths, a_strides); -ck_tile::HostTensor b_host(b_lengths, b_strides); -ck_tile::HostTensor c_host(c_lengths, c_strides); -``` - -**What's happening:** -- We create three tensors on the host (CPU) memory -- Each tensor is defined by its shape (`lengths`) and memory layout (`strides`) -- `HostTensor` is a CK Tile utility class that manages CPU memory - -**Stride explanation:** -- For A: `stride_a = K` means moving to the next row requires skipping K elements -- For B: `stride_b = K` means B is stored in transposed format -- For C: `stride_c = N` means row-major layout - ---- - -## Step 4: Initialize Tensors with Random Data - -```cpp -ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); -ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); -c_host.SetZero(); -``` - -**What's happening:** -- A and B are filled with random values in the range [-5.0, 5.0] -- C is initialized to zero (will store the output) - -**Optional: Print Tensor Contents** -```cpp -// Commented out in the code, but available for debugging: -// a_host.print_first_n(10); // Print first 10 elements of A -``` - -The `print_first_n()` helper function can display tensor contents for debugging purposes. - ---- - -## Step 5: Allocate Device Memory and Transfer Data - -```cpp -ck_tile::DeviceMem a_device(a_host); -ck_tile::DeviceMem b_device(b_host); -ck_tile::DeviceMem c_device(c_host); -``` - -**What's happening:** -- `DeviceMem` allocates GPU memory matching the size of host tensors -- The constructor **automatically transfers data from host to device** -- This is a convenience wrapper around `hipMalloc` and `hipMemcpy` - -**Memory Flow:** -``` -CPU (Host) GPU (Device) -┌─────────┐ ┌─────────┐ -│ a_host │ ────────> │a_device │ -│ b_host │ ────────> │b_device │ -│ c_host │ ────────> │c_device │ -└─────────┘ └─────────┘ -``` - ---- - -## Step 6: Configure Hierarchical Tiling - -```cpp -using BlockTile = ck_tile::sequence<256, 128, 32>; -using WaveTile = ck_tile::sequence<16, 16, 16>; -``` - -**What's happening:** -- We define a two-level tiling hierarchy for the GEMM computation - -### Block Tile (256 × 128 × 32) -- **256**: M dimension per block (rows of A and C) -- **128**: N dimension per block (columns of B and C) -- **32**: K dimension per block (inner dimension) -- Each block tile is processed by one **thread block** (256 threads) - -### Wave Tile (16 × 16 × 16) -- **16 × 16**: Output tile dimensions (M × N) per warp iteration -- **16**: K dimension per warp iteration -- Each wave tile is processed by one **warp** (64 threads on AMD GPUs) - -**Important:** The WaveTile (16×16×16) is NOT the same as the MFMA instruction size (32×32×8). The WaveTile represents the work done per warp per iteration, while MFMA is the underlying hardware instruction. Multiple MFMA operations may be needed to compute one wave tile - -**Important Note:** -In this example, the problem size (256 × 128 × 32) is **identical** to the block tile size, so only **one thread block** is needed to compute the entire problem. - -### Tiling Visualization: - -#### Matrix A (M × K = 256 × 32): -``` -┌─────────────────────────────────────┐ -│ One Block Tile (256 × 32) │ -│ ┌────┬────┐ │ -│ │16×│16× │ ← Wave tiles (16×16) │ -│ │ 16│ 16 │ in M×K space │ -│ ├────┼────┤ │ -│ │ │ │ │ -│ ├────┼────┤ │ -│ │ .. │ .. │ 16 tiles in M │ -│ ├────┼────┤ 2 tiles in K │ -│ │ │ │ │ -│ └────┴────┘ │ -│ │ -└─────────────────────────────────────┘ -``` - -#### Matrix B (N × K = 128 × 32): -``` -┌──────────────────────────────┐ -│ One Block Tile (128 × 32) │ -│ ┌────┬────┐ │ -│ │16×│16× │ ← Wave tiles │ -│ │ 16│ 16 │ (16×16) │ -│ ├────┼────┤ │ -│ │ │ │ │ -│ ├────┼────┤ 8 tiles in N │ -│ │ .. │ .. │ 2 tiles in K │ -│ ├────┼────┤ │ -│ │ │ │ │ -│ └────┴────┘ │ -└──────────────────────────────┘ -``` - -#### Matrix C (M × N = 256 × 128) - Output: -``` -┌─────────────────────────────────────────────────┐ -│ One Block Tile (256 × 128) │ -│ │ -│ ┌────┬────┬────┬────┬────┬────┬────┬────┐ │ -│ │16× │ │ │ │ │ │ │ │ │ -│ │ 16 │ │ │ │ │ │ │ │ │ -│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ -│ │ │ │ │ │ │ │ │ │ │ -│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ -│ │ │ │ │ │ │ │ │ │ │ -│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ -│ │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ │ -│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ -│ │ │ │ │ │ │ │ │ │ │ -│ └────┴────┴────┴────┴────┴────┴────┴────┘ │ -│ │ -│ 16 wave tiles in M direction │ -│ 8 wave tiles in N direction │ -│ Total: 128 wave tiles (16×16 each) │ -└─────────────────────────────────────────────────┘ -``` - -#### How Wave Tiles Combine (C = A × B): -``` -Matrix A Matrix B (stored transposed N×K) Matrix C -(256×32) (128×32) (256×128) - -Row of A tiles: Row of B tiles: One wave tile in C: -┌────┬────┐ ┌────┬────┐ ┌────┐ -│ A₀ │ A₁ │ × │ B₀ │ B₁ │ = │ C │ (16×16) -└────┴────┘ └────┴────┘ └────┘ - 16×16 each 16×16 each - -Computation: C = A₀×B₀ᵀ + A₁×B₁ᵀ - ↑ ↑ - K=0..15 K=16..31 - -Each wave tile in C is computed by: -- Taking one row of wave tiles from A (2 tiles along K) -- Taking one row of wave tiles from B (2 tiles along K) - Note: B is stored transposed (N×K), so a "row" in storage corresponds - to a "column" in the logical B^T matrix used in computation -- Performing dot product: Σ(A_k × B_k^T) for k=0,1 -``` - -**Key Insight:** -- Each **wave tile in C** (16×16) requires a **dot product** of 2 wave tiles from A and 2 wave tiles from B -- Since B is stored transposed (N×K layout), we access **rows** of B tiles in memory -- This is the fundamental operation repeated across all 128 wave tiles in C -- Each warp computes one wave tile using MFMA instructions - ---- - -## Step 7: Create Shape, Problem, and Policy Structs - -```cpp -using PracticeGemmShape = ck_tile::PracticeGemmShape; -std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl; - -using PracticeGemmHostProblem = ck_tile:: - PracticeGemmHostProblem; - -using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy; -``` - -**What's happening:** - -### 1. **Shape Struct** -Encapsulates all tile shape information (BlockTile and WaveTile dimensions). - -### 2. **Problem Struct** -Holds complete problem description: -- Data types (ADataType, BDataType, CDataType, AccDataType) -- Shape information (BlockTile, WaveTile) - -In more complex examples, this would also include: -- Data layouts (row-major, column-major) -- Mathematical operations (e.g., transposed GEMM) - -### 3. **Policy Struct** -Describes data movement and thread-to-data mapping: -- Currently contains `MakeBlock2TileMap()`: Maps thread block IDs to tile positions -- In more complex kernels, includes: - - DRAM access patterns - - LDS (Local Data Share) usage strategies - - Thread distribution within blocks - -**CK Tile Design Pattern:** -``` -Kernel = Problem + Policy + Epilogue - ↑ ↑ ↑ - (What) (How) (Post-processing) -``` - ---- - -## Step 8: Calculate Grid and Block Dimensions - -```cpp -ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) * - ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N); - -std::cout << "kGridSize: " << kGridSize << std::endl; - -constexpr ck_tile::index_t kBlockSize = 256; -constexpr ck_tile::index_t kBlockPerCU = 1; -``` - -**What's happening:** - -### Grid Size Calculation -```cpp -kGridSize = ceil(M / BlockTile_M) × ceil(N / BlockTile_N) - = ceil(512 / 256) × ceil(256 / 128) - = 2 × 2 - = 4 thread blocks -``` - -Our problem requires **4 thread blocks** to cover the entire output matrix C (2 blocks in M direction, 2 blocks in N direction). - -### Block Configuration -- `kBlockSize = 256`: Each thread block has 256 threads - - 256 threads / 64 threads per warp = **4 warps per block** -- `kBlockPerCU = 1`: Launch 1 block per Compute Unit (for simplicity) - -**Thread Hierarchy:** -``` -GPU -└── 1 Thread Block (Grid) - └── 256 Threads - ├── Warp 0 (threads 0-63) - ├── Warp 1 (threads 64-127) - ├── Warp 2 (threads 128-191) - └── Warp 3 (threads 192-255) -``` - ---- - -## Step 9: Create and Launch the Kernel - -```cpp -using gemm_kernel = - ck_tile::PracticeGemmKernel; - -float ave_time = ck_tile::launch_kernel( - ck_tile::stream_config{nullptr, true, 0, 0, 1}, - ck_tile::make_kernel(gemm_kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(a_device.GetDeviceBuffer()), - static_cast(b_device.GetDeviceBuffer()), - static_cast(c_device.GetDeviceBuffer()), - M, - N, - K, - stride_a, - stride_b, - stride_c)); -``` - -**What's happening:** - -### 1. Kernel Composition -```cpp -using gemm_kernel = ck_tile::PracticeGemmKernel; -``` -The kernel is composed from Problem and Policy structs, following the CK Tile design pattern. - -### 2. Kernel Launch -`launch_kernel()` is a CK Tile utility that: -- Launches the GPU kernel using HIP runtime -- Measures execution time -- Returns average execution time in milliseconds - -### 3. Launch Parameters -- **Stream config**: `{nullptr, true, 0, 0, 1}` - default stream, timing enabled -- **Grid size**: `kGridSize = 1` - number of thread blocks -- **Block size**: `kBlockSize = 256` - threads per block -- **Shared memory**: `0` - no dynamic shared memory in this example -- **Kernel arguments**: Device pointers and problem dimensions - -### 4. Kernel Execution Flow -``` -launch_kernel() calls gemm_kernel.operator()() - ↓ -PracticeGemmKernel::operator() - ↓ -Creates tensor views over device memory - ↓ -Calls block-level pipeline - ↓ -Block pipeline calls warp-level pipeline - ↓ -Warp pipeline calls MFMA instructions - ↓ -Results written back to C matrix -``` - ---- - -## Step 10: Verify Results - -```cpp -auto pass = true; - -if(verification) -{ - // Reference gemm on CPU - ck_tile::HostTensor c_host_ref(c_lengths, c_strides); - reference_basic_gemm( - a_host, b_host, c_host_ref); - - // Copy GPU results back to host - ck_tile::HostTensor c_host_dev(c_lengths, c_strides); - c_device.FromDevice(c_host_dev.mData.data()); - - // Compare results - pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3); - std::cout << "valid:" << (pass ? "y" : "n") << std::endl; -} -``` - -**What's happening:** - -### 1. CPU Reference Implementation -```cpp -reference_basic_gemm<...>(a_host, b_host, c_host_ref); -``` -Computes GEMM on CPU using a simple nested loop implementation (ground truth). - -### 2. Copy GPU Results to Host -```cpp -c_device.FromDevice(c_host_dev.mData.data()); -``` -Transfers the computed result from GPU memory back to CPU for comparison. - -### 3. Error Checking -```cpp -ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3); -``` -Compares GPU and CPU results element-wise with tolerance: -- **Relative error**: 1e-3 (0.1%) -- **Absolute error**: 1e-3 - -**Verification Flow:** -``` -CPU GPU -┌─────────┐ ┌─────────┐ -│ a_host │ ────────> │a_device │ -│ b_host │ ────────> │b_device │ -└─────────┘ └─────────┘ - │ │ - ↓ ↓ -reference_gemm() GPU kernel - │ │ - ↓ ↓ -┌──────────┐ ┌──────────┐ -│c_host_ref│ │c_device │ -└──────────┘ └──────────┘ - │ │ - │ ↓ - │ FromDevice() - │ │ - ↓ ↓ - └────> check_err() <───┘ - │ - ↓ - Pass/Fail -``` - ---- - -## Complete Execution Flow Summary - -``` -1. Define data types (FP16 inputs, FP32 output) - ↓ -2. Set problem size (M=256, N=128, K=32) - ↓ -3. Create host tensors and initialize with random data - ↓ -4. Allocate device memory and transfer data (CPU → GPU) - ↓ -5. Configure hierarchical tiling (BlockTile, WaveTile) - ↓ -6. Create Shape, Problem, and Policy structs - ↓ -7. Calculate grid/block dimensions (1 block, 256 threads) - ↓ -8. Compose and launch kernel (Problem + Policy) - ↓ -9. Execute GEMM on GPU - │ ├─ Block-level pipeline - │ ├─ Warp-level pipeline - │ └─ MFMA instructions - ↓ -10. Verify results (compare GPU vs CPU reference) - ↓ -11. Calculate and print performance metrics - ↓ -12. Return success/failure -``` - ---- \ No newline at end of file diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_asmem_bsmem_creg.hpp similarity index 100% rename from tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg.hpp rename to tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_asmem_bsmem_creg.hpp diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg_policy.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_asmem_bsmem_creg_policy.hpp similarity index 67% rename from tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg_policy.hpp rename to tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_asmem_bsmem_creg_policy.hpp index 188e481c65..258c12e4fe 100644 --- a/tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg_policy.hpp +++ b/tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_asmem_bsmem_creg_policy.hpp @@ -6,6 +6,14 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +// Controls whether to use the A/B-swapped MFMA variant with transposed C register layout. +// 0 = WarpGemmMfmaF16F16F32M32N32K8 (standard, no swap, no transposed C) +// 1 = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution (swap A/B in MFMA + transposed C +// layout) +#ifndef CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION +#define CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION 1 +#endif + namespace ck_tile { // Default policy for BlockGemmASmemBSmemCReg @@ -15,24 +23,31 @@ struct BlockGemmASmemBSmemCRegPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - // NAIVE_IMPLEMENTATION uses 4x1 warp configuration constexpr index_t kMWarp = 4; constexpr index_t kNWarp = 1; - // NAIVE_IMPLEMENTATION uses mfma m32 n32 k8 + // mfma m32 n32 k8 if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { +#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION return make_tuple( WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, kMWarp, kNWarp); +#endif } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { +#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION return make_tuple( WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); +#else + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8{}, kMWarp, kNWarp); +#endif } else { diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp similarity index 100% rename from tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp rename to tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp similarity index 98% rename from tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp rename to tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp index 421a63649f..aae07331b7 100644 --- a/tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp +++ b/tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp @@ -3,7 +3,7 @@ #pragma once -#include "../warp_level/block_gemm_asmem_bsmem_creg.hpp" +#include "block_gemm_asmem_bsmem_creg.hpp" #include "ck_tile/core.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/host_level/grid_gemm.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/grid_gemm.hpp similarity index 100% rename from tutorial/ck_tile/gemm/01_naive_gemm/host_level/grid_gemm.hpp rename to tutorial/ck_tile/gemm/01_naive_gemm/grid_gemm.hpp diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp index 50a49d23fb..2c4137837f 100644 --- a/tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp +++ b/tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp @@ -8,8 +8,8 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" -#include "block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp" -#include "host_level/grid_gemm.hpp" +#include "block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "grid_gemm.hpp" namespace ck_tile { diff --git a/tutorial/ck_tile/tile_distribution/CMakeLists.txt b/tutorial/ck_tile/tile_distribution/CMakeLists.txt new file mode 100644 index 0000000000..91947ac4fb --- /dev/null +++ b/tutorial/ck_tile/tile_distribution/CMakeLists.txt @@ -0,0 +1,21 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# These tutorials are hard-coded for CDNA (warp_size=64) with specific tile sizes. +# Only build for gfx942 (MI300X) and gfx950 (MI350X). +if(NOT (GPU_TARGETS MATCHES "gfx942|gfx950")) + message(VERBOSE "Skipping tile_distribution tutorials: requires gfx942 or gfx950") + return() +endif() + +foreach(i 1 2 3) + set(TUTORIAL_NAME "tile_tutorial_tile_distribution_${i}") + + add_executable(${TUTORIAL_NAME} EXCLUDE_FROM_ALL tile_distribution_${i}.cpp) + target_include_directories(${TUTORIAL_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + target_compile_options(${TUTORIAL_NAME} PRIVATE + -Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported + ) + + add_dependencies(tutorials ${TUTORIAL_NAME}) +endforeach() diff --git a/tutorial/ck_tile/tile_distribution/README.md b/tutorial/ck_tile/tile_distribution/README.md new file mode 100644 index 0000000000..97ea643947 --- /dev/null +++ b/tutorial/ck_tile/tile_distribution/README.md @@ -0,0 +1,63 @@ +# CK Tile Distribution Encoding Tutorial + +## Overview + +Every `load_tile` and `store_tile` in CK needs to know **which thread reads which data element**. +This mapping is defined by a `tile_distribution_encoding` — a compile-time struct with 6 template +parameters: + +```cpp +tile_distribution_encoding +``` + +Every level of **Hs** (hierarchical dimensions) is assigned to exactly one role: + +| Role | Meaning | +|------|---------| +| **P** (parallel) | Thread ID selects which slice — different threads get different data | +| **Y** (yield) | Each thread owns the entire range in its buffer | +| **R** (replicate) | Identical data broadcast to multiple thread groups | + +## Tutorials + +These tutorials use the exact tile sizes from the naive GEMM tutorial +(`01_naive_gemm/`): MPerBlock=256, NPerBlock=128, KPerBlock=32, BlockSize=256, fp16. + +| # | File | Matrix | Tile | Key Concept | +|---|------|--------|------|-------------| +| 1 | `tile_distribution_1.cpp` | A (DRAM load) | 256×32 | NDimP=2, warp\_id→M1, lane\_id→M2×K0 (coalesced) | +| 2 | `tile_distribution_2.cpp` | B (DRAM load) | 128×32 | Same pattern as A, but N0=2 iterations (vs A's M0=4) due to smaller N | +| 3 | `tile_distribution_3.cpp` | C (registers) | 256×128 | Warp-level MFMA output + block-level composition, standard vs transposed | + +Tutorial 3 responds to `CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION` — rebuild with `=0` or `=1` +to see both C register layouts. + +**Architecture note:** All comments and concrete values assume **CDNA (warp_size=64)**. +On RDNA (warp_size=32), the thread-to-data mapping will differ. + +## Building + +```bash +cd /projects/composablekernel/build + +# Build all tutorials: +make tutorials -j +# or: ninja tutorials + +# Or build individually: +make tile_tutorial_tile_distribution_1 -j +make tile_tutorial_tile_distribution_2 -j +make tile_tutorial_tile_distribution_3 -j + +# Tutorial 3 with standard (non-transposed) C: +cmake -DCMAKE_CXX_FLAGS="-DCK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0" .. +make tile_tutorial_tile_distribution_3 -j +``` + +## Reference + +- Encoding definition: `include/ck_tile/core/tensor/tile_distribution_encoding.hpp` +- Thread identity (NDimP): `include/ck_tile/core/tensor/tile_distribution.hpp` +- MFMA warp output layout: `include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp` +- Production A/B distributions: `include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp` +- Naive GEMM tutorial: `tutorial/ck_tile/gemm/01_naive_gemm/` diff --git a/tutorial/ck_tile/tile_distribution/tile_distribution_1.cpp b/tutorial/ck_tile/tile_distribution/tile_distribution_1.cpp new file mode 100644 index 0000000000..a764677c90 --- /dev/null +++ b/tutorial/ck_tile/tile_distribution/tile_distribution_1.cpp @@ -0,0 +1,285 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/* + * Tutorial: CK Tile Distribution Encoding — A Matrix DRAM Load + * + * Demonstrates how tile_distribution_encoding maps threads to A-matrix + * elements during a DRAM load in the naive GEMM tutorial. + * + * Source: block_gemm_pipeline_agmem_bgmem_creg_policy.hpp + * MakeADramTileDistribution(), with fp16, BlockSize=256 + * + * Tile: M=256 × K=32 (matches the naive GEMM's A block tile) + * Threads: 256 (4 warps on CDNA, 8 on RDNA) + * + * Host initialises A with sequential values 0, 1, 2, ... (row-major). + * A[m][k] = m * K + k, so the printed value directly gives the linear index. + * GPU kernel loads A using the distribution, then prints per-thread buffer + * contents so the reader can verify which elements each thread received. + * + * Note: int32_t is used instead of fp16 for readable printf output. + * The distribution encoding is hardcoded to match the fp16 derivation + * (K1=16/sizeof(fp16)=8), not recomputed from sizeof(int32_t). + * + * No compute is performed — this is purely about data movement. + * + * Note: Comments and values assume CDNA (warp_size=64). On RDNA (warp_size=32), + * the thread-to-data mapping will differ. + */ + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include + +using namespace ck_tile; + +// ============================================================================ +// THE GOAL +// ============================================================================ +// Matrix A: M=256 rows × K=32 columns, stored in DRAM (row-major, fp16). +// Load the entire tile into registers using 256 threads (4 warps on CDNA). +// +// For coalesced memory access with fp16, each lane loads 8 contiguous +// K-values (8 × 2 bytes = 16 bytes = 128 bits). Since K=32, we need +// 32/8 = 4 lanes to cover one row: +// +// lane 0: K=0..7 lane 1: K=8..15 lane 2: K=16..23 lane 3: K=24..31 +// └──────────────── one row of 32 K-columns ──────────────────────────────┘ +// +// With warp_size=64, each warp has 64 lanes. 4 lanes per row means +// 64/4 = 16 rows per warp. With 4 warps, one pass covers 4×16 = 64 rows. +// To cover all 256 rows, each thread iterates M0 = 256/64 = 4 times. +// +// Per-thread buffer = 4 iterations × 8 K-values = 32 elements. +// +// Visually for warp 0 (lanes 0–63): +// +// A matrix (256×32) lane_id decomposition +// ──────────────── ────────────────────── +// row 0: [ K=0..7 | 8..15 | 16..23 | 24..31 ] +// L0 L1 L2 L3 ← iter 0 +// row 1: [ K=0..7 | 8..15 | 16..23 | 24..31 ] +// L4 L5 L6 L7 +// ... +// row 15: same pattern, lanes 60–63 +// ────── stride of 64 rows (4 warps × 16 rows/warp) ────── +// row 64: L0..L3 ← iter 1 +// ... +// row 128: L0..L3 ← iter 2 +// ... +// row 192: L0..L3 ← iter 3 +// +// ============================================================================ +// THE SOLUTION: tile_distribution_encoding +// ============================================================================ +// +// Production code derives (fp16, BlockSize=256, MPerBlock=256, KPerBlock=32): +// K1 = 16/sizeof(fp16) = 8 → vector load width (8 values) +// K0 = KPerBlock/K1 = 4 → 4 K-chunks per row +// M2 = warp_size/K0 = 16 → 16 rows per warp +// M1 = BlockSize/warp_size = 4 → 4 warps +// M0 = MPerBlock/(M2*M1) = 4 → 4 iterations +// +// Step 1 — Hierarchical dimensions (Hs): factor each axis. +// +// Hs[0] = sequence<4, 4, 16> → M = 4 × 4 × 16 = 256 +// Hs[1] = sequence<4, 8> → K = 4 × 8 = 32 +// +// Hs[0] Hs[1] +// ┌─────┼─────┐ ┌───┴───┐ +// level 0 level 1 level 2 level 0 level 1 +// = 4 = 4 = 16 = 4 = 8 +// +// Step 2 — Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id). +// +// P0 = warp_id → Hs[0][1] = 4 (which warp → which M-group) +// P1 = lane_id → Hs[0][2]=16 AND Hs[1][0]=4 (merged, total=64) +// +// The merge transform decomposes lane_id: +// row_in_warp = lane_id / 4 (0..15, outer) +// k_chunk = lane_id % 4 (0..3, inner → coalesced!) +// +// Ps_major = tuple, sequence<1, 2>> +// Ps_minor = tuple, sequence<2, 0>> +// +// How to read Ps: the tuple has 2 elements → NDimP=2. +// First element = P0 = warp_id +// Second element = P1 = lane_id +// +// Ps_major = tuple< seq<1>, seq<1, 2> > +// ─P0(warp)─ ─P1(lane)── +// Ps_minor = tuple< seq<1>, seq<2, 0> > +// ─P0(warp)─ ─P1(lane)── +// +// P0: major=<1>, minor=<1> → Hs[0], level 1 → M1=4 +// P1: major=<1,2>, minor=<2,0> → merged: +// Hs[0] level 2 → M2=16 (outer, changes slowly) +// Hs[1] level 0 → K0=4 (inner, changes every lane → coalesced!) +// Total: 16 × 4 = 64 = warp_size +// lane / 4 → row_in_warp (M2), lane % 4 → K-chunk (K0) +// +// Step 3 — Yield dimensions (Ys): what each thread owns. +// +// Y0 = Hs[0][0] = 4 (M-iterations) +// Y1 = Hs[1][1] = 8 (vector load width) +// +// Ys_major = sequence<1, 2> +// Ys_minor = sequence<0, 1> +// +// How to read Ys: parallel arrays — position i gives Yi. +// +// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1] +// Ys_minor = seq< 0, 1 > → Y0 is level 0, Y1 is level 1 +// ─Y0─ ─Y1─ +// +// Y0: Hs[0] level 0 → M0=4 (iterations along M) +// Y1: Hs[1] level 1 → K1=8 (vector load width) +// Buffer size = Y0 × Y1 = 4 × 8 = 32 elements per thread. +// +// Step 4 — Replicate: Rs = sequence<1> (trivial, size 1). +// +// Complete tree: +// +// Hs[0] Hs[1] +// ┌─────┼─────┐ ┌───┴───┐ +// [Y0] [P0] [P1] [P1] [Y1] +// = 4 = 4 = 16 = 4 = 8 +// (iter) (warp) (row) (K-chunk) (vec load) +// +// Buffer size = Y0 × Y1 = 4 × 8 = 32 elements per thread. +// +// ============================================================================ + +static constexpr index_t kM = 256; +static constexpr index_t kK = 32; + +struct TileDistKernelA +{ + static constexpr index_t kBlockSize = 256; + + CK_TILE_DEVICE void operator()(const int32_t* p_data) const + { + static_assert(get_warp_size() == 64, + "This tutorial is hard-coded for CDNA (warp_size=64). " + "On RDNA (warp_size=32), the encoding values and print logic must change."); + + const auto a_tensor = make_naive_tensor_view( + p_data, make_tuple(kM, kK), make_tuple(kK, 1), number<1>{}, number<1>{}); + + constexpr auto distribution = make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence<4, 8>>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + + auto window = make_tile_window( + a_tensor, make_tuple(number{}, number{}), {0, 0}, distribution); + + const auto tile = load_tile(window); + + const auto& buf = tile.get_thread_buffer(); + constexpr index_t warp_size = get_warp_size(); + constexpr index_t kBufSize = 32; // 4 iterations × 8 K-values + + int32_t local_buf[kBufSize]; + static_for<0, kBufSize, 1>{}([&](auto i) { local_buf[i] = static_cast(buf[i]); }); + + auto print_thread = [&](int tid) { + if(static_cast(threadIdx.x) == tid) + { + int lane = tid % static_cast(warp_size); + int warp = tid / static_cast(warp_size); + int row_in_wrp = lane / 4; + int k_chunk = lane % 4; + + printf("Thread %3d (warp %d, lane %2d) row_in_warp=%2d k_chunk=%d\n", + tid, + warp, + lane, + row_in_wrp, + k_chunk); + + for(int iter = 0; iter < 4; iter++) + { + int row = iter * 64 + warp * 16 + row_in_wrp; + int col = k_chunk * 8; + printf(" iter %d: A[%3d][%2d..%2d] =", iter, row, col, col + 7); + for(int k = 0; k < 8; k++) + printf(" %5d", local_buf[iter * 8 + k]); + printf("\n"); + } + } + }; + + if(blockIdx.x == 0) + { + if(threadIdx.x == 0) + { + printf("\n=== Tile Distribution: A-Matrix DRAM Load ===\n"); + printf("Source: MakeADramTileDistribution (fp16, BlockSize=256)\n"); + printf("Tile: %dx%d BlockSize: %d WarpSize: %d Warps: %d\n", + kM, + kK, + kBlockSize, + static_cast(warp_size), + kBlockSize / static_cast(warp_size)); + printf("Each thread: 4 iterations x 8 K-values = 32 elements\n\n"); + printf("Coalescing: lanes 0-3 read K=0..31 of the same row\n"); + printf(" (4 x 8 = 32 K-values = one full row)\n\n"); + } + __syncthreads(); + + // Lane 0: row_in_warp=0, k_chunk=0 → rows {0, 64, 128, 192}, K=0..7 + print_thread(0); + __syncthreads(); + // Lane 1: k_chunk=1 → same rows, K=8..15 (coalesced with lane 0) + print_thread(1); + __syncthreads(); + // Lane 4: row_in_warp=1 → rows {1, 65, 129, 193}, K=0..7 + print_thread(4); + __syncthreads(); + + if(threadIdx.x == 0) + printf("\n--- Warp 1 ---\n"); + __syncthreads(); + // Warp 1, Lane 0: rows {16, 80, 144, 208}, K=0..7 + print_thread(static_cast(warp_size)); + __syncthreads(); + + if(threadIdx.x == 0) + printf("\n--- Warp 3 (last) ---\n"); + __syncthreads(); + // Warp 3, Lane 63: rows {63, 127, 191, 255}, K=24..31 + print_thread(kBlockSize - 1); + __syncthreads(); + } + } +}; + +int main() +{ + printf("=== CK Tile Distribution Tutorial 1: A-Matrix DRAM Load ===\n"); + printf("=== Matches naive GEMM: MPerBlock=256, KPerBlock=32 ===\n\n"); + + HostTensor h_tensor({kM, kK}); + for(int i = 0; i < kM * kK; i++) + h_tensor.mData[i] = i; + + printf("Host matrix A[%d x %d], row-major, A[m][k] = m*%d + k\n\n", kM, kK, kK); + + DeviceMem d_data(h_tensor); + + launch_kernel(stream_config{}, + make_kernel<1>(TileDistKernelA{}, + dim3(1), + dim3(TileDistKernelA::kBlockSize), + 0, + static_cast(d_data.GetDeviceBuffer()))); + hip_check_error(hipDeviceSynchronize()); + + printf("Done.\n"); + return 0; +} diff --git a/tutorial/ck_tile/tile_distribution/tile_distribution_2.cpp b/tutorial/ck_tile/tile_distribution/tile_distribution_2.cpp new file mode 100644 index 0000000000..5d5ae3227f --- /dev/null +++ b/tutorial/ck_tile/tile_distribution/tile_distribution_2.cpp @@ -0,0 +1,240 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/* + * Tutorial: CK Tile Distribution Encoding — B Matrix DRAM Load + * + * Demonstrates how tile_distribution_encoding maps threads to B-matrix + * elements during a DRAM load in the naive GEMM tutorial. + * + * Source: block_gemm_pipeline_agmem_bgmem_creg_policy.hpp + * MakeBDramTileDistribution(), with fp16, BlockSize=256 + * + * Tile: N=128 × K=32 (matches the naive GEMM's B block tile) + * Threads: 256 (4 warps on CDNA, 8 on RDNA) + * + * The B encoding has the SAME structure as the A encoding (Tutorial 1), + * but with N=128 instead of M=256. This changes only N0 (the iteration + * count), showing how the same encoding pattern adapts to different + * tile sizes. + * + * No compute is performed — this is purely about data movement. + * + * Note: int32_t is used instead of fp16 for readable printf output. + * The distribution encoding is hardcoded to match the fp16 derivation. + * + * Note: Comments and values assume CDNA (warp_size=64). On RDNA (warp_size=32), + * the thread-to-data mapping will differ. + */ + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include + +using namespace ck_tile; + +// ============================================================================ +// THE GOAL +// ============================================================================ +// Matrix B: N=128 rows × K=32 columns, stored in DRAM (row-major, fp16). +// (In GEMM, B is stored as [N, K] — each "row" is one output channel.) +// Load the entire tile into registers using 256 threads (4 warps on CDNA). +// +// Same coalescing strategy as the A-matrix (Tutorial 1): +// - 4 lanes cover one K-row (4 × 8 = 32 K-values) +// - Each warp (64 lanes) covers 16 N-rows +// - 4 warps cover 64 N-rows per iteration +// - N0 = 128/64 = 2 iterations (vs 4 for A's M=256) +// +// Per-thread buffer = 2 iterations × 8 K-values = 16 elements. +// +// Compare with Tutorial 1 (A-matrix): +// A: M=256, M0=4, buffer=32 | B: N=128, N0=2, buffer=16 +// Everything else is identical — same K-splitting, same coalescing. +// +// ============================================================================ +// THE SOLUTION: tile_distribution_encoding +// ============================================================================ +// +// Production code derives (fp16, BlockSize=256, NPerBlock=128, KPerBlock=32): +// K1 = 16/sizeof(fp16) = 8 +// K0 = KPerBlock/K1 = 4 +// N2 = warp_size/K0 = 16 +// N1 = BlockSize/warp_size = 4 +// N0 = NPerBlock/(N2*N1) = 2 +// +// Step 1 — Hierarchical dimensions (Hs): +// +// Hs[0] = sequence<2, 4, 16> → N = 2 × 4 × 16 = 128 +// Hs[1] = sequence<4, 8> → K = 4 × 8 = 32 +// +// Hs[0] Hs[1] +// ┌─────┼─────┐ ┌───┴───┐ +// [Y0] [P0] [P1] [P1] [Y1] +// = 2 = 4 = 16 = 4 = 8 +// (iter) (warp) (row) (K-chunk) (vec load) +// +// Step 2 — Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id). +// +// Ps_major = tuple, sequence<1, 2>> +// Ps_minor = tuple, sequence<2, 0>> +// +// How to read Ps: the tuple has 2 elements → NDimP=2. +// First element = P0 = warp_id +// Second element = P1 = lane_id +// +// P0: major=<1>, minor=<1> → Hs[0], level 1 → N1=4 (which warp) +// P1: major=<1,2>, minor=<2,0> → merged: +// Hs[0] level 2 → N2=16 (outer, row within warp) +// Hs[1] level 0 → K0=4 (inner, K-chunk → coalesced!) +// lane / 4 → row_in_warp, lane % 4 → K-chunk +// +// Step 3 — Yield dimensions (Ys): what each thread owns. +// +// Ys_major = sequence<1, 2> +// Ys_minor = sequence<0, 1> +// +// How to read Ys: parallel arrays — position i gives Yi. +// +// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1] +// Ys_minor = seq< 0, 1 > → Y0 is level 0, Y1 is level 1 +// ─Y0─ ─Y1─ +// +// Y0: Hs[0] level 0 → N0=2 (iterations along N) +// Y1: Hs[1] level 1 → K1=8 (vector load width) +// +// Buffer size = Y0 × Y1 = 2 × 8 = 16 elements per thread. +// +// ============================================================================ + +static constexpr index_t kN = 128; +static constexpr index_t kK = 32; + +struct TileDistKernelB +{ + static constexpr index_t kBlockSize = 256; + + CK_TILE_DEVICE void operator()(const int32_t* p_data) const + { + static_assert(get_warp_size() == 64, + "This tutorial is hard-coded for CDNA (warp_size=64). " + "On RDNA (warp_size=32), the encoding values and print logic must change."); + + const auto b_tensor = make_naive_tensor_view( + p_data, make_tuple(kN, kK), make_tuple(kK, 1), number<1>{}, number<1>{}); + + constexpr auto distribution = make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence<4, 8>>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + + auto window = make_tile_window( + b_tensor, make_tuple(number{}, number{}), {0, 0}, distribution); + + const auto tile = load_tile(window); + + const auto& buf = tile.get_thread_buffer(); + constexpr index_t warp_size = get_warp_size(); + constexpr index_t kBufSize = 16; // 2 iterations × 8 K-values + + int32_t local_buf[kBufSize]; + static_for<0, kBufSize, 1>{}([&](auto i) { local_buf[i] = static_cast(buf[i]); }); + + auto print_thread = [&](int tid) { + if(static_cast(threadIdx.x) == tid) + { + int lane = tid % static_cast(warp_size); + int warp = tid / static_cast(warp_size); + int row_in_wrp = lane / 4; + int k_chunk = lane % 4; + + printf("Thread %3d (warp %d, lane %2d) row_in_warp=%2d k_chunk=%d\n", + tid, + warp, + lane, + row_in_wrp, + k_chunk); + + for(int iter = 0; iter < 2; iter++) + { + int row = iter * 64 + warp * 16 + row_in_wrp; + int col = k_chunk * 8; + printf(" iter %d: B[%3d][%2d..%2d] =", iter, row, col, col + 7); + for(int k = 0; k < 8; k++) + printf(" %4d", local_buf[iter * 8 + k]); + printf("\n"); + } + } + }; + + if(blockIdx.x == 0) + { + if(threadIdx.x == 0) + { + printf("\n=== Tile Distribution: B-Matrix DRAM Load ===\n"); + printf("Source: MakeBDramTileDistribution (fp16, BlockSize=256)\n"); + printf("Tile: %dx%d BlockSize: %d WarpSize: %d Warps: %d\n", + kN, + kK, + kBlockSize, + static_cast(warp_size), + kBlockSize / static_cast(warp_size)); + printf("Each thread: 2 iterations x 8 K-values = 16 elements\n"); + printf("Compare with Tutorial 1 (A): same K-split, but N0=2 vs M0=4\n\n"); + } + __syncthreads(); + + // Lane 0: row_in_warp=0, k_chunk=0 → rows {0, 64}, K=0..7 + print_thread(0); + __syncthreads(); + // Lane 1: k_chunk=1 → same rows, K=8..15 + print_thread(1); + __syncthreads(); + // Lane 4: row_in_warp=1 → rows {1, 65}, K=0..7 + print_thread(4); + __syncthreads(); + + if(threadIdx.x == 0) + printf("\n--- Warp 1 ---\n"); + __syncthreads(); + // Warp 1, Lane 0: rows {16, 80}, K=0..7 + print_thread(static_cast(warp_size)); + __syncthreads(); + + if(threadIdx.x == 0) + printf("\n--- Warp 3 (last) ---\n"); + __syncthreads(); + // Warp 3, Lane 63: rows {63, 127}, K=24..31 + print_thread(kBlockSize - 1); + __syncthreads(); + } + } +}; + +int main() +{ + printf("=== CK Tile Distribution Tutorial 2: B-Matrix DRAM Load ===\n"); + printf("=== Matches naive GEMM: NPerBlock=128, KPerBlock=32 ===\n\n"); + + HostTensor h_tensor({kN, kK}); + for(int i = 0; i < kN * kK; i++) + h_tensor.mData[i] = i; + + printf("Host matrix B[%d x %d], row-major, B[n][k] = n*%d + k\n\n", kN, kK, kK); + + DeviceMem d_data(h_tensor); + + launch_kernel(stream_config{}, + make_kernel<1>(TileDistKernelB{}, + dim3(1), + dim3(TileDistKernelB::kBlockSize), + 0, + static_cast(d_data.GetDeviceBuffer()))); + hip_check_error(hipDeviceSynchronize()); + + printf("Done.\n"); + return 0; +} diff --git a/tutorial/ck_tile/tile_distribution/tile_distribution_3.cpp b/tutorial/ck_tile/tile_distribution/tile_distribution_3.cpp new file mode 100644 index 0000000000..4a782b592b --- /dev/null +++ b/tutorial/ck_tile/tile_distribution/tile_distribution_3.cpp @@ -0,0 +1,376 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/* + * Tutorial: CK Tile Distribution Encoding — C Matrix Register Layout + * + * Demonstrates how C-matrix elements are distributed across thread registers + * after MFMA computation. Unlike A/B (which are DRAM loads), C lives entirely + * in registers — the distribution describes which thread holds which output + * element of C = A × B. + * + * This tutorial shows BOTH: + * 1. The warp-level C distribution (from MFMA m32n32k8 output mapping) + * 2. The block-level outer distribution (how multiple warps tile C) + * 3. The composed distribution (what CK actually uses) + * + * The macro CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION (default 1) selects + * between the standard and transposed C register layouts. + * + * Tile: M=256 × N=128 (matches the naive GEMM's C block tile) + * Warp config: MWarp=4, NWarp=1 + * MFMA: m32n32k8 (each warp produces a 32×32 output) + * + * No actual MFMA compute — we construct a C distributed tensor, fill it + * with marker values (thread_id * 1000 + buffer_index), and print per-thread + * contents to reveal which buffer slots belong to which thread. + * + * Note: Comments and values assume CDNA (warp_size=64). On RDNA (warp_size=32), + * the thread-to-data mapping will differ. + */ + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include + +using namespace ck_tile; + +// Controls which C register layout to demonstrate +#ifndef CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION +#define CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION 1 +#endif + +// ============================================================================ +// THE GOAL +// ============================================================================ +// After GEMM computation, each thread holds a subset of the C matrix +// (M=256 × N=128 = 32768 elements) in its registers. We want to understand +// exactly which C[m][n] elements each thread owns. +// +// The mapping has two levels: +// +// BLOCK LEVEL (256×128 → warps and iterations): +// - 4 warps along M (MWarp=4), 1 warp along N (NWarp=1) +// - Each warp covers 32 M-rows × 128 N-cols of the block tile +// - Within each warp: MIterPerWarp=2, NIterPerWarp=4 +// → 2 × 4 = 8 warp-tile iterations per warp +// - Each warp-tile iteration is a 32×32 MFMA output +// +// WARP LEVEL (32×32 → threads): +// - 64 threads produce 32 × 32 = 1024 C elements +// - Each thread holds 1024/64 = 16 elements +// - MFMA m32n32k8 arranges these 16 elements in a specific pattern +// +// The per-thread register buffer = 8 iterations × 16 elements = 128 floats. +// +// ============================================================================ +// THE SOLUTION: Two-Level Distribution +// ============================================================================ +// +// --- WARP-LEVEL C DISTRIBUTION (from MFMA m32n32k8) --- +// +// For fp16→fp32 MFMA m32n32k8 output (kCM0PerLane=4, kCMLane=2, +// kCM1PerLane=4, kCNLane=32): +// +// STANDARD (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0): +// +// Hs[0] = sequence<4, 2, 4> → M-dim: 4 × 2 × 4 = 32 +// Hs[1] = sequence<32> → N-dim: 32 +// Ps_major = tuple> → lane maps to Hs[0][1] and Hs[1][0] +// Ps_minor = tuple> +// +// How to read Ps: the tuple has 1 element → NDimP=1 → P0 = lane_id. +// P0: major=<1,2>, minor=<1,0> → merged: +// Hs[0] level 1 → kCMLane=2 (outer, M-half) +// Hs[1] level 0 → kCNLane=32 (inner, N-col → contiguous!) +// lane / 32 → M-half, lane % 32 → N-col +// +// Ys_major = sequence<1, 1> +// Ys_minor = sequence<0, 2> +// +// How to read Ys: parallel arrays — position i gives Yi. +// +// Ys_major = seq< 1, 1 > → Y0 is in Hs[0], Y1 is in Hs[0] +// Ys_minor = seq< 0, 2 > → Y0 is level 0, Y1 is level 2 +// ─Y0─ ─Y1─ +// +// Y0: Hs[0] level 0 → kCM0PerLane=4 (M outer per lane) +// Y1: Hs[0] level 2 → kCM1PerLane=4 (M inner per lane) +// +// Hs[0] Hs[1] +// ┌─────┼─────┐ │ +// [Y0] [P0] [Y1] [P0] +// = 4 = 2 = 4 = 32 +// (M outer)(lane) (M inner) (lane → N) +// +// Per-thread: Y0 × Y1 = 4 × 4 = 16 elements per warp-tile. +// Lane decomposition: lane / 32 → M-half (0..1), lane % 32 → N-col (0..31) +// +// TRANSPOSED (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=1): +// +// Hs[0] = sequence<32> → First dim: N (swapped!) +// Hs[1] = sequence<4, 2, 4> → Second dim: M (swapped!) +// Ps_major = tuple> → lane maps to Hs[1][1] and Hs[0][0] +// Ps_minor = tuple> +// +// How to read Ps: tuple has 1 element → NDimP=1 → P0 = lane_id. +// P0: major=<2,1>, minor=<1,0> → merged: +// Hs[1] level 1 → kCMLane=2 (outer, M-half) +// Hs[0] level 0 → kCNLane=32 (inner, N-col → contiguous!) +// Same lane decomposition as standard, but dimensions are swapped. +// +// Ys_major = sequence<2, 2> +// Ys_minor = sequence<0, 2> +// +// How to read Ys: +// Ys_major = seq< 2, 2 > → Y0 is in Hs[1], Y1 is in Hs[1] +// Ys_minor = seq< 0, 2 > → Y0 is level 0, Y1 is level 2 +// ─Y0─ ─Y1─ +// +// Y0: Hs[1] level 0 → kCM0PerLane=4 (M outer per lane) +// Y1: Hs[1] level 2 → kCM1PerLane=4 (M inner per lane) +// Same 16 elements, but now both Y dims are in Hs[1] (M is second). +// +// Hs[0] Hs[1] +// │ ┌─────┼─────┐ +// [P0] [Y0] [P0] [Y1] +// = 32 = 4 = 2 = 4 +// (lane → N) (M outer)(lane)(M inner) +// +// Same 16 elements per thread, but N is the first dimension in the +// distribution — this changes which elements are contiguous in the +// thread buffer, affecting downstream store coalescing. +// +// --- BLOCK-LEVEL OUTER DISTRIBUTION --- +// +// MIterPerWarp = MPerBlock / (MWarp × WarpGemm::kM) = 256 / (4 × 32) = 2 +// NIterPerWarp = NPerBlock / (NWarp × WarpGemm::kN) = 128 / (1 × 32) = 4 +// +// Hs[0] = sequence<2, 4> → M-dim: 2 iters × 4 warps +// Hs[1] = sequence<4, 1> → N-dim: 4 iters × 1 warp +// Ps_major = tuple> +// Ps_minor = tuple> +// +// How to read Ps: tuple has 1 element → NDimP=1 → P0 = warp_id. +// P0: major=<1,2>, minor=<1,1> → merged: +// Hs[0] level 1 → MWarp=4 (outer) +// Hs[1] level 1 → NWarp=1 (inner, trivial) +// Total: 4 × 1 = 4 = number of warps +// +// Ys_major = sequence<1, 2> +// Ys_minor = sequence<0, 0> +// +// How to read Ys: +// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1] +// Ys_minor = seq< 0, 0 > → Y0 is level 0, Y1 is level 0 +// ─Y0─ ─Y1─ +// +// Y0: Hs[0] level 0 → MIterPerWarp=2 +// Y1: Hs[1] level 0 → NIterPerWarp=4 +// Block-level buffer = Y0 × Y1 = 2 × 4 = 8 warp-tile slots. +// +// tile_distribution_encoding, +// tuple, sequence<4, 1>>, +// tuple>, tuple>, +// sequence<1, 2>, sequence<0, 0>> +// +// --- COMPOSED (what CK uses) --- +// +// make_embed_tile_distribution_encoding(block_outer, warp_encoding) +// embeds the warp encoding inside each (MIter, MWarp, NIter, NWarp) cell. +// Total per-thread buffer = 2 × 4 × 16 = 128 elements. +// +// ============================================================================ + +static constexpr index_t kM = 256; +static constexpr index_t kN = 128; + +#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION +using WarpGemm = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; +#else +using WarpGemm = WarpGemmMfmaF16F16F32M32N32K8; +#endif + +static constexpr index_t kMWarp = 4; +static constexpr index_t kNWarp = 1; + +static constexpr index_t kMIterPerWarp = kM / (kMWarp * WarpGemm::kM); // 2 +static constexpr index_t kNIterPerWarp = kN / (kNWarp * WarpGemm::kN); // 4 + +struct TileDistKernelC +{ + static constexpr index_t kBlockSize = 256; + + CK_TILE_DEVICE void operator()() const + { + static_assert(get_warp_size() == 64, + "This tutorial is hard-coded for CDNA (warp_size=64). " + "On RDNA (warp_size=32), the encoding values and print logic must change."); + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + constexpr index_t kBufSize = c_block_tensor.get_thread_buffer_size(); + + // Fill each thread's buffer with a marker value: + // We can't easily set C[m][n] = m*N + n without knowing the inverse mapping, + // so instead we fill with thread_id * 1000 + buffer_index to identify ownership. + static_for<0, kBufSize, 1>{}([&](auto i) { + c_block_tensor.get_thread_buffer()(i) = + static_cast(threadIdx.x * 1000 + static_cast(i)); + }); + + constexpr index_t warp_size = get_warp_size(); + + // Copy compile-time-indexed buffer into a plain array for runtime printing + float local_buf[kBufSize]; + static_for<0, kBufSize, 1>{}( + [&](auto i) { local_buf[i] = c_block_tensor.get_thread_buffer()[i]; }); + + auto print_thread = [&](int tid) { + if(static_cast(threadIdx.x) == tid) + { + int lane = tid % static_cast(warp_size); + int warp = tid / static_cast(warp_size); + + printf("Thread %3d (warp %d, lane %2d) buf_size=%d\n", + tid, + warp, + lane, + static_cast(kBufSize)); + +#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION + printf(" Layout: TRANSPOSED (N is first dimension)\n"); +#else + printf(" Layout: STANDARD (M is first dimension)\n"); +#endif + + printf(" Block-level: MIterPerWarp=%d, NIterPerWarp=%d\n", + static_cast(kMIterPerWarp), + static_cast(kNIterPerWarp)); + printf(" Warp-level: 16 elements per warp-tile (32x32 MFMA output)\n"); + printf(" Total: %d x %d x 16 = %d elements\n", + static_cast(kMIterPerWarp), + static_cast(kNIterPerWarp), + static_cast(kBufSize)); + + constexpr int kPerWarpTile = 16; + for(int mIter = 0; mIter < static_cast(kMIterPerWarp); mIter++) + { + for(int nIter = 0; nIter < static_cast(kNIterPerWarp); nIter++) + { + int base = (mIter * static_cast(kNIterPerWarp) + nIter) * kPerWarpTile; + printf(" [mIter=%d, nIter=%d] buf[%3d..%3d]:", + mIter, + nIter, + base, + base + kPerWarpTile - 1); + for(int k = 0; k < kPerWarpTile; k++) + { + printf(" %.0f", static_cast(local_buf[base + k])); + } + printf("\n"); + } + } + } + }; + + if(blockIdx.x == 0) + { + if(threadIdx.x == 0) + { + printf("\n=== Tile Distribution: C-Matrix Register Layout ===\n"); + printf("Tile: %dx%d BlockSize: %d WarpSize: %d\n", + static_cast(kM), + static_cast(kN), + static_cast(kBlockSize), + static_cast(warp_size)); + printf("MWarp=%d, NWarp=%d, MFMA=m32n32k8\n", + static_cast(kMWarp), + static_cast(kNWarp)); + printf("MIterPerWarp=%d, NIterPerWarp=%d\n", + static_cast(kMIterPerWarp), + static_cast(kNIterPerWarp)); +#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION + printf("Mode: TRANSPOSED C (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=1)\n"); + printf(" WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution\n"); + printf(" Warp encoding: , tuple, seq<4,2,4>>,\n"); + printf(" tuple>, tuple>,\n"); + printf(" seq<2,2>, seq<0,2>>\n"); +#else + printf("Mode: STANDARD C (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0)\n"); + printf(" WarpGemmMfmaF16F16F32M32N32K8\n"); + printf(" Warp encoding: , tuple, seq<32>>,\n"); + printf(" tuple>, tuple>,\n"); + printf(" seq<1,1>, seq<0,2>>\n"); +#endif + printf("\nBlock outer: , tuple, seq<%d,%d>>,\n", + static_cast(kMIterPerWarp), + static_cast(kMWarp), + static_cast(kNIterPerWarp), + static_cast(kNWarp)); + printf(" tuple>, tuple>,\n"); + printf(" seq<1,2>, seq<0,0>>\n\n"); + } + __syncthreads(); + + // Warp 0, Lane 0 + print_thread(0); + __syncthreads(); + // Warp 0, Lane 32 (different M-half in standard, different N in transposed) + print_thread(32); + __syncthreads(); + + if(threadIdx.x == 0) + printf("\n--- Warp 1 (covers different M-rows than warp 0) ---\n"); + __syncthreads(); + print_thread(static_cast(warp_size)); + __syncthreads(); + + if(threadIdx.x == 0) + printf("\n--- Warp 3 (last) ---\n"); + __syncthreads(); + print_thread(kBlockSize - 1); + __syncthreads(); + } + } +}; + +int main() +{ + printf("=== CK Tile Distribution Tutorial 3: C-Matrix Register Layout ===\n"); + printf("=== Matches naive GEMM: MPerBlock=256, NPerBlock=128 ===\n\n"); + printf("MFMA m32n32k8: each warp produces 32x32 = 1024 elements\n"); + printf(" 64 threads per warp → 16 elements per thread per warp-tile\n"); + printf(" MWarp=4, NWarp=1 → 4 warps along M, 1 along N\n"); + printf(" MIterPerWarp=2, NIterPerWarp=4 → 8 warp-tiles per warp\n"); + printf(" Total per thread: 8 × 16 = 128 elements\n\n"); + +#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION + printf("Current mode: TRANSPOSED C distribution\n"); + printf(" Rebuild with -DCK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0 for standard\n\n"); +#else + printf("Current mode: STANDARD C distribution\n"); + printf(" Rebuild with -DCK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=1 for transposed\n\n"); +#endif + + launch_kernel(stream_config{}, + make_kernel<1>(TileDistKernelC{}, dim3(1), dim3(TileDistKernelC::kBlockSize), 0)); + hip_check_error(hipDeviceSynchronize()); + + printf("Done.\n"); + return 0; +}