From efcd6297d4aca927b58f45c1fbdf40e16f6de322 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 11 Nov 2025 15:15:49 -0500 Subject: [PATCH] Add CK Tile Tutorials Folder with GEMM and COPY Kernel (#3038) * feat: add tutorial folder with gemm tutorial * chore: move copy kernel from examples folder to tutorial * Update tutorial/ck_tile/01_naive_gemm/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tutorial/ck_tile/01_naive_gemm/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * chore: remove handdrawn images * docs: add write ups to explain the gemm kernel * docs: add about block level pipeline and static distributed tensors --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> [ROCm/composable_kernel commit: b145a5fe80d2f9d965f2c8555808017c3a660fc2] --- CMakeLists.txt | 6 + example/ck_tile/CMakeLists.txt | 1 - tutorial/CMakeLists.txt | 15 + .../ck_tile/00_copy_kernel}/CMakeLists.txt | 6 +- .../ck_tile/00_copy_kernel}/README.md | 0 .../ck_tile/00_copy_kernel}/copy_basic.cpp | 22 +- .../ck_tile/00_copy_kernel}/copy_basic.hpp | 0 .../00_copy_kernel}/test_tile_example.sh | 2 +- .../01_naive_gemm/BLOCK_LEVEL_PIPELINE.md | 589 +++++++++++++++++ tutorial/ck_tile/01_naive_gemm/CMakeLists.txt | 7 + .../01_naive_gemm/HOST_LEVEL_PIPELINE.md | 618 ++++++++++++++++++ .../01_naive_gemm/KERNEL_ENTRY_POINT.md | 464 +++++++++++++ tutorial/ck_tile/01_naive_gemm/README.md | 150 +++++ tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md | 506 ++++++++++++++ ...e_gemm_block_pipeline_agmem_bgmem_creg.hpp | 165 +++++ ...ice_gemm_block_policy_agmem_bgmem_creg.hpp | 135 ++++ ...ce_gemm_host_pipeline_agmem_bgmem_creg.hpp | 92 +++ ...tice_gemm_host_policy_agmem_bgmem_creg.hpp | 51 ++ .../ck_tile/01_naive_gemm/practice_gemm.cpp | 131 ++++ .../ck_tile/01_naive_gemm/practice_gemm.hpp | 69 ++ .../ck_tile/01_naive_gemm/reference_gemm.hpp | 36 + ...ce_gemm_warp_pipeline_asmem_bsmem_creg.hpp | 195 ++++++ ...tice_gemm_warp_policy_asmem_bsmem_creg.hpp | 35 + tutorial/ck_tile/CMakeLists.txt | 7 + 24 files changed, 3287 insertions(+), 15 deletions(-) create mode 100644 tutorial/CMakeLists.txt rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/CMakeLists.txt (54%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/README.md (100%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/copy_basic.cpp (86%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/copy_basic.hpp (100%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/test_tile_example.sh (95%) create mode 100644 tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md create mode 100644 tutorial/ck_tile/01_naive_gemm/CMakeLists.txt create mode 100644 tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md create mode 100644 tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md create mode 100644 tutorial/ck_tile/01_naive_gemm/README.md create mode 100644 tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md create mode 100644 tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp create 
mode 100644 tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp create mode 100644 tutorial/ck_tile/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 049da5637f..7b4990dba4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -683,6 +683,12 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) PACKAGE_NAME examples ) add_subdirectory(example) + + add_subdirectory(tutorial) + rocm_package_setup_component(tutorials + LIBRARY_NAME composablekernel + PACKAGE_NAME tutorials + ) add_subdirectory(tile_engine) if(BUILD_TESTING) add_subdirectory(test) diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index a6cfcde86e..92ee0a4c31 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -25,7 +25,6 @@ add_subdirectory(22_gemm_multi_abd) add_subdirectory(35_batched_transpose) add_subdirectory(36_pooling) add_subdirectory(38_block_scale_gemm) -add_subdirectory(39_copy) add_subdirectory(40_streamk_gemm) add_subdirectory(41_batched_contraction) diff --git a/tutorial/CMakeLists.txt b/tutorial/CMakeLists.txt new file mode 100644 index 0000000000..a2f35ca53f --- /dev/null +++ b/tutorial/CMakeLists.txt @@ -0,0 +1,15 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include +) + +message(STATUS "Building tutorials...") +add_custom_target(tutorials) + +# add all tutorial subdir +file(GLOB dir_list LIST_DIRECTORIES true *) +FOREACH(subdir ${dir_list}) + if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt") + add_subdirectory(${subdir}) + ENDIF() +ENDFOREACH() diff --git a/example/ck_tile/39_copy/CMakeLists.txt b/tutorial/ck_tile/00_copy_kernel/CMakeLists.txt similarity index 54% rename from example/ck_tile/39_copy/CMakeLists.txt rename to tutorial/ck_tile/00_copy_kernel/CMakeLists.txt index 98397a33d2..91dd036eff 100644 --- a/example/ck_tile/39_copy/CMakeLists.txt +++ b/tutorial/ck_tile/00_copy_kernel/CMakeLists.txt @@ -1,7 +1,9 @@ -add_executable(tile_example_copy EXCLUDE_FROM_ALL copy_basic.cpp) +add_executable(tile_tutorial_copy_kernel EXCLUDE_FROM_ALL copy_basic.cpp) # Impact: This flag ensures that the compiler doesn't make # assumptions about memory aliasing that could interfere with Composable Kernel's explicit memory access patterns. 
-target_compile_options(tile_example_copy PRIVATE +target_compile_options(tile_tutorial_copy_kernel PRIVATE -mllvm -enable-noalias-to-md-conversion=0 ) + +add_dependencies(tutorials tile_tutorial_copy_kernel) diff --git a/example/ck_tile/39_copy/README.md b/tutorial/ck_tile/00_copy_kernel/README.md similarity index 100% rename from example/ck_tile/39_copy/README.md rename to tutorial/ck_tile/00_copy_kernel/README.md diff --git a/example/ck_tile/39_copy/copy_basic.cpp b/tutorial/ck_tile/00_copy_kernel/copy_basic.cpp similarity index 86% rename from example/ck_tile/39_copy/copy_basic.cpp rename to tutorial/ck_tile/00_copy_kernel/copy_basic.cpp index de91dc1be9..282e9ff8c1 100644 --- a/example/ck_tile/39_copy/copy_basic.cpp +++ b/tutorial/ck_tile/00_copy_kernel/copy_basic.cpp @@ -54,10 +54,10 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); // Define tile configuration - using ThreadTile = ck_tile::sequence<1, 4>; // per-thread tile size along M and N - using WaveTile = ck_tile::sequence<64, 4>; // wave size along M and N dimension - using BlockWaves = ck_tile::sequence<4, 1>; // number of waves along M dimension - using BlockTile = ck_tile::sequence<512, 4>; // block size along M and N dimension + using ThreadTile = ck_tile::sequence<1, 4>; // per-thread tile size along M and N + using WaveTile = ck_tile::sequence<64, 4>; // per-wave tile size along M and N dimension + using BlockWaves = ck_tile::sequence<4, 1>; // number of waves per block along M and N dimension + using BlockTile = ck_tile::sequence<512, 4>; // per-block tile size along M and N dimension // Calculate grid size ck_tile::index_t kGridSize = @@ -68,14 +68,14 @@ bool run(const ck_tile::ArgParser& arg_parser) using Shape = ck_tile::TileCopyShape; using Problem = ck_tile::TileCopyProblem; using Policy = ck_tile::TileCopyPolicy; - using Kernel = ck_tile::ElementWiseTileCopyKernel; - // using Kernel = ck_tile::TileCopyKernel; - // using Kernel = ck_tile::TileCopyKernel_LDS; + using Kernel = ck_tile::ElementWiseTileCopyKernel<Problem, Policy>; // operates on an + // element-by-element basis. - // question: Why do we not have a pipeline? - // answer: For basic copy operation, pipeline is not needed. - // we intentionally do not use pipeline for this example and let the kernel be composite of - // Problem and Policy + // We also implement two variations of the copy kernel: + // 1. TileCopyKernel: This is the basic copy kernel that operates on a tile-by-tile basis. + // 2. TileCopyKernel_LDS: This is the copy kernel that operates on a tile-by-tile basis and uses + // the LDS. To try one of them, swap in:
+ // using Kernel = ck_tile::TileCopyKernel<Problem, Policy>; + // using Kernel = ck_tile::TileCopyKernel_LDS<Problem, Policy>; auto blockSize = Kernel::BlockSize(); diff --git a/example/ck_tile/39_copy/copy_basic.hpp b/tutorial/ck_tile/00_copy_kernel/copy_basic.hpp similarity index 100% rename from example/ck_tile/39_copy/copy_basic.hpp rename to tutorial/ck_tile/00_copy_kernel/copy_basic.hpp diff --git a/example/ck_tile/39_copy/test_tile_example.sh b/tutorial/ck_tile/00_copy_kernel/test_tile_example.sh similarity index 95% rename from example/ck_tile/39_copy/test_tile_example.sh rename to tutorial/ck_tile/00_copy_kernel/test_tile_example.sh index 416338fac4..4ee5fdf15d 100755 --- a/example/ck_tile/39_copy/test_tile_example.sh +++ b/tutorial/ck_tile/00_copy_kernel/test_tile_example.sh @@ -4,7 +4,7 @@ set -euo pipefail -BIN="${BIN:-../../../build/bin/tile_example_copy}" +BIN="${BIN:-../../../build/bin/tile_tutorial_copy_kernel}" WARMUP="${WARMUP:-20}" REPEAT="${REPEAT:-100}" VALIDATE="${VALIDATE:-1}" diff --git a/tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md b/tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md new file mode 100644 index 0000000000..114fccfd56 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md @@ -0,0 +1,589 @@ +# Block-Level Pipeline: PracticeGemmBlockPipelineAGmemBGmemCreg + +## Overview + +The **Block-Level Pipeline** is where the actual GEMM computation happens for one block tile. It orchestrates: +1. **Data movement** from DRAM → Registers → LDS +2. **GEMM computation** using data in LDS +3. **Iteration** over the K dimension when needed + +This pipeline is called by the host-level pipeline for each block tile that covers a portion of the output matrix C. + +--- + +## Architecture: Problem and Policy + +Like other components in CK Tile, the block pipeline follows the **Problem/Policy** pattern: + +### Problem: `PracticeGemmBlockPipelineProblem` +Contains: +- **Data types**: `ADataType`, `BDataType`, `CDataType`, `AccDataType` +- **Shape information**: `BlockTile` and `WaveTile` dimensions + +### Policy: `PracticeGemmBlockPolicy` +Contains strategies for: +1. **Tile Distribution** (`MakeADramTileDistribution`, `MakeBDramTileDistribution`) + - Defines how 256 threads in a block map to elements of a block tile + - Each thread knows which elements to load/store from DRAM to its registers + - We'll cover tile distribution construction in detail later + +2. **LDS Layout** (`MakeALdsBlockDescriptor`, `MakeBLdsBlockDescriptor`) + - Describes how data is logically organized in Local Data Share (LDS) + - Optimizes for bank conflict avoidance and efficient access patterns + - We'll cover LDS descriptor construction in detail later + +3.
**Warp Pipeline** (`GetPracticeWaveGemmPipeline`) + - Returns the warp-level GEMM implementation + +--- + +## Inputs and Outputs + +```cpp +template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp> +CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const +``` + +### Inputs: +- `a_dram_block_window_tmp`: Tile window over A in DRAM (size: MPerBlock × KPerBlock) +- `b_dram_block_window_tmp`: Tile window over B in DRAM (size: NPerBlock × KPerBlock) +- `num_loop`: Number of iterations along K dimension +- `p_smem`: Pointer to shared memory (LDS) + +### Output: +- `c_block_tile`: A `static_distributed_tensor` containing the computed C tile in registers (VGPRs) + +--- + +## Step-by-Step Walkthrough + +### Step 1: Create LDS Tensor Views + +```cpp +// A tile in LDS +ADataType* p_a_lds = static_cast<ADataType*>(p_smem); +constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>(); +auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc); + +// B tile in LDS (placed after A in shared memory) +BDataType* p_b_lds = static_cast<BDataType*>( + static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned)); +constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>(); +auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc); +``` + +**What's happening:** +- We partition the shared memory (`p_smem`) into two regions: one for A, one for B +- We create **tensor views** over these LDS regions using descriptors from the policy +- `a_lds_block` and `b_lds_block` are logical views over raw LDS memory + +**Memory Layout:** +``` +Shared Memory (LDS): +┌─────────────────────┬─────────────────────┐ +│ A Block Tile │ B Block Tile │ +│ (256×32 fp16) │ (128×32 fp16) │ +└─────────────────────┴─────────────────────┘ +↑ ↑ +p_a_lds p_b_lds +``` + +--- + +### Step 2: Create Tile Windows for Data Movement + +We create **6 tile windows** for different purposes: + +#### 2a. DRAM → Registers (Load from DRAM) + +```cpp +auto a_copy_dram_window = make_tile_window( + a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32 + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution<Problem>()); // ← Tile distribution! +``` + +**Key Points:** +- `a_copy_dram_window` is a `tile_window_with_static_distribution` +- The **tile distribution** tells each thread which elements to load from DRAM +- This window will **slide along the K dimension** in the loop + +#### 2b. Registers → LDS (Store to LDS) + +```cpp +auto a_copy_lds_window = make_tile_window( + a_lds_block, + make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32 + {0, 0}, // Origin at (0, 0) in LDS + a_copy_dram_window.get_tile_distribution()); // ← Same distribution as DRAM! +``` + +**Key Points:** +- Uses the **same tile distribution** as `a_copy_dram_window` +- This ensures each thread stores to LDS in the same pattern it loaded from DRAM +- Origin is always `{0, 0}` because LDS is reused for each K iteration + +#### 2c. LDS → Registers (GEMM Input) + +```cpp +auto a_lds_gemm_window = make_tile_window( + a_lds_block, + make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), + {0, 0}); // No tile distribution!
+``` + +**Key Points:** +- This is a `tile_window_with_static_lengths` (no explicit distribution) +- Used as input to the warp-level GEMM +- The warp GEMM will handle its own thread mapping internally + +**Similar windows are created for B:** +- `b_copy_dram_window`: Load B from DRAM +- `b_copy_lds_window`: Store B to LDS +- `b_lds_gemm_window`: Read B from LDS for GEMM + +--- + +### Step 3: Create Distributed Tensors (VGPRs) + +```cpp +using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); +using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + +using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{})); +using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{})); + +ABlockTile a_block_tile; // Per-thread registers for A +BBlockTile b_block_tile; // Per-thread registers for B +``` + +#### What is `make_static_distributed_tensor`? + +**`make_static_distributed_tensor`** creates a **`static_distributed_tensor`**, which is a compile-time abstraction for **distributed per-thread register storage**. + +**Key Properties:** +1. **Per-thread VGPRs**: Each thread owns a **different slice** of the tile in its registers +2. **Compile-time sized**: Buffer size determined by tile distribution at compile time +3. **Zero-overhead**: All indexing and layout transformations happen at compile time + +**How it works:** + +```cpp +template <typename DataType_, typename StaticTileDistribution_> +struct static_distributed_tensor +{ + using DataType = remove_cvref_t<DataType_>; + using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>; + + // Calculate per-thread storage size from tile distribution + using ThreadTensorDesc = + remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>; + + static constexpr index_t kThreadElementSpaceSize = + ThreadTensorDesc{}.get_element_space_size(); + + // Per-thread register array (VGPRs) + thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_; +}; +``` + +**The tile distribution defines:** +- **Which elements each thread owns** in the tile +- **How many elements** each thread stores (buffer size) +- **How elements are laid out** in each thread's registers + +**Concrete Example for 256×32 tile with 256 threads:** + +``` +Thread 0: a_block_tile.thread_buf_ = [A[0,0], A[0,1], ..., A[0,31]] (32 fp16 values) +Thread 1: a_block_tile.thread_buf_ = [A[1,0], A[1,1], ..., A[1,31]] (32 fp16 values) +Thread 2: a_block_tile.thread_buf_ = [A[2,0], A[2,1], ..., A[2,31]] (32 fp16 values) +... +Thread 255: a_block_tile.thread_buf_ = [A[255,0], A[255,1], ..., A[255,31]] (32 fp16 values) +``` + +**Collectively:** +- All 256 threads together hold the **entire 256×32 tile** (8192 elements) +- Each thread's buffer lives in its **own VGPRs** +- No two threads own the same element + +**Distributed Ownership Analogy:** +Think of a tile as a **jigsaw puzzle**: +- The **tile distribution** is the cutting pattern +- Each **thread** gets one puzzle piece (its slice) +- The **`static_distributed_tensor`** is the full puzzle, held collectively across all threads +- Each thread's **`thread_buf_`** is its individual piece in its own registers + +--- + +### Step 4: The GEMM Loop + +```cpp +// Initialize C accumulator to zero +auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; +tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + +index_t iCounter = num_loop; // Number of K iterations + +while(iCounter > 0) +{ + // 1. Load from DRAM to registers + a_block_tile = load_tile(a_copy_dram_window); // DRAM → VGPRs + b_block_tile = load_tile(b_copy_dram_window); // DRAM → VGPRs + + // 2.
Move windows for next iteration + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); // Step by (0, 32) + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // Step by (0, 32) + + // 3. Store from registers to LDS + store_tile(a_copy_lds_window, a_block_tile); // VGPRs → LDS + store_tile(b_copy_lds_window, b_block_tile); // VGPRs → LDS + + // 4. Synchronize threads (ensure all data is in LDS) + block_sync_lds(); + + // 5. Compute GEMM using data in LDS + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + // 6. Synchronize threads (before overwriting LDS in next iteration) + block_sync_lds(); + + iCounter--; +} + +return c_block_tile; // Return accumulated result in registers +``` + +--- + +## Detailed Loop Breakdown + +### Phase 1: Load (DRAM → VGPRs) + +```cpp +a_block_tile = load_tile(a_copy_dram_window); +``` + +**What happens:** +1. Each thread reads **its assigned elements** from DRAM (determined by tile distribution) +2. Data is loaded into **per-thread registers** (VGPRs) +3. Uses **vectorized loads** for efficiency (e.g., loading 8 fp16 values at once) + +**Example for Thread 0:** +``` +Thread 0 loads: + A[0,0:7] (8 fp16 values, one vector load) + A[1,0:7] (8 fp16 values, one vector load) + ... +``` + +### Phase 2: Move Windows + +```cpp +constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); +move_tile_window(a_copy_dram_window, a_dram_tile_window_step); +``` + +**What happens:** +- The tile window **slides along the K dimension** by `KPerBlock` (32 in our example) +- This prepares for the next K iteration +- The window origin moves from `(0, 0)` → `(0, 32)` → `(0, 64)` → ... + +**Visualization for Problem Size 512×256×64:** +``` +Matrix A (512×64): +┌─────────────────────────────────────┐ +│ Block 0: rows 0-255 │ +│ ┌──────────┬──────────┐ │ +│ │ K=0:31 │ K=32:63 │ │ ← Window slides right +│ │ Iter 0 │ Iter 1 │ │ +│ └──────────┴──────────┘ │ +└─────────────────────────────────────┘ +``` + +### Phase 3: Store (VGPRs → LDS) + +```cpp +store_tile(a_copy_lds_window, a_block_tile); +``` + +**What happens:** +1. Each thread writes **its elements** from registers to LDS +2. Uses the **same distribution** as the DRAM load +3. Data is now in **shared memory**, accessible to all threads in the block + +**Why this step?** +- GEMM computation needs **all threads** to access **all data** +- Registers are per-thread; LDS is shared across the block +- LDS acts as a "staging area" for collaborative computation + +### Phase 4: Synchronize + +```cpp +block_sync_lds(); +``` + +**What happens:** +- All threads in the block **wait** until everyone has finished storing to LDS +- Ensures no thread starts reading from LDS before all writes are complete +- Critical for correctness! + +### Phase 5: GEMM Computation + +```cpp +block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); +``` + +**What happens:** +1. The warp-level GEMM reads data from LDS +2. Performs matrix multiplication using MFMA instructions +3. Accumulates results into `c_block_tile` (in registers) + +**Note:** `c_block_tile` stays in registers throughout all K iterations, accumulating results. 
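+
+To sanity-check the phases above, it helps to see the same computation without any tiling. The sketch below is a minimal host-side reference of our own (not the tutorial's `reference_gemm.hpp`); it assumes row-major A (M×K), B stored row-major as N×K (so each C(m,n) is a dot product of a row of A with a row of B, matching the `B^T` convention used here), and plain `float` standing in for the data and accumulator types:
+
+```cpp
+// Reference semantics of the block pipeline's K loop (illustrative sketch).
+void reference_gemm(const float* a, const float* b, float* c, int M, int N, int K)
+{
+    for(int m = 0; m < M; ++m)
+        for(int n = 0; n < N; ++n)
+        {
+            float acc = 0.f; // plays the role of c_block_tile's accumulator
+            // The pipeline splits this K loop into num_loop chunks of KPerBlock,
+            // staging each chunk in LDS before accumulating.
+            for(int k = 0; k < K; ++k)
+                acc += a[m * K + k] * b[n * K + k];
+            c[m * N + n] = acc;
+        }
+}
+```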
+ +### Phase 6: Synchronize Again + +```cpp +block_sync_lds(); +``` + +**What happens:** +- Ensures all threads have finished reading from LDS +- Safe to overwrite LDS in the next iteration + +--- + +## Memory Flow Diagram + +``` +Iteration 0 (K=0:31): +┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐ +│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │ +│ A[0:255,│ │ (per- │ │ A_block │ +│ 0:31] │ │ thread) │ │ │ +└─────────┘ └──────────┘ └─────────┘ + │ + │ block_gemm + ↓ + ┌──────────┐ + │ c_block_ │ + │ tile │ + │ (VGPRs) │ + └──────────┘ + +Iteration 1 (K=32:63): +┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐ +│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │ +│ A[0:255,│ │ (per- │ │ A_block │ +│ 32:63] │ │ thread) │ │ (reused)│ +└─────────┘ └──────────┘ └─────────┘ + │ + │ block_gemm + ↓ + ┌──────────┐ + │ c_block_ │ + │ tile │ + │ (accum.) │ + └──────────┘ +``` + +--- + +## Example: Problem Size 512×256×64 + +### Block 0 Computation + +**Input:** +- `a_dram_block_window_tmp`: Covers A[0:255, 0:31] initially +- `b_dram_block_window_tmp`: Covers B[0:127, 0:31] initially (B is transposed) +- `num_loop`: 2 (since K=64, KPerBlock=32) + +**Iteration 0:** +1. Load A[0:255, 0:31] and B[0:127, 0:31] from DRAM to VGPRs +2. Move windows: A → [0:255, 32:63], B → [0:127, 32:63] +3. Store to LDS +4. Compute: `C[0:255, 0:127] += A[0:255, 0:31] × B[0:127, 0:31]^T` + +**Iteration 1:** +1. Load A[0:255, 32:63] and B[0:127, 32:63] from DRAM to VGPRs +2. Move windows: A → [0:255, 64:95], B → [0:127, 64:95] (out of bounds, but loop ends) +3. Store to LDS +4. Compute: `C[0:255, 0:127] += A[0:255, 32:63] × B[0:127, 32:63]^T` + +**Output:** +- `c_block_tile`: Contains C[0:255, 0:127] in distributed registers + +--- + +## Key Concepts Summary + +### 1. Tile Distribution +- **Maps threads to data elements** for load/store operations +- Each thread knows exactly which elements it's responsible for +- Enables **parallel, vectorized** memory access +- **Same distribution** used for DRAM load and LDS store + +### 2. Static Distributed Tensor +- **Per-thread register storage** (VGPRs) +- Each thread owns a **different slice** of the tile +- **Compile-time sized** for zero-overhead abstraction +- Used for: `a_block_tile`, `b_block_tile`, `c_block_tile` + +### 3. Tile Window Movement +- Windows **slide** over larger tensors +- Enables iteration over the K dimension +- `move_tile_window(window, step)` updates the origin + +### 4. LDS as Staging Area +- **Shared memory** accessible to all threads in a block +- Required because GEMM needs all threads to access all data +- **Reused** across K iterations (same LDS buffer) + +### 5. 
Synchronization +- `block_sync_lds()` ensures memory consistency +- **Before GEMM**: All stores to LDS are complete +- **After GEMM**: All reads from LDS are complete + +--- + +## Deep Dive: `static_distributed_tensor` Mechanics + +### How Tile Distribution Creates Per-Thread Storage + +When you call: +```cpp +using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{})); +ABlockTile a_block_tile; +``` + +**Step 1: Extract Thread Tensor Descriptor** + +The tile distribution contains a `ys_to_d_descriptor` that maps: +- **Y dimensions** (logical tile coordinates, e.g., M, K) +- **D dimension** (per-thread register index, linearized) + +```cpp +using ThreadTensorDesc = + decltype(StaticTileDistribution{}.get_ys_to_d_descriptor()); +``` + +**Step 2: Calculate Per-Thread Buffer Size** + +```cpp +static constexpr index_t kThreadElementSpaceSize = + ThreadTensorDesc{}.get_element_space_size(); + +static constexpr index_t get_thread_buffer_size() +{ + return kThreadElementSpaceSize / PackedSize; +} +``` + +**Example:** +- 256×32 tile distributed across 256 threads +- Each thread owns 32 elements (one row) +- `thread_buffer_size = 32` (for PackedSize=1) + +**Step 3: Allocate Thread Buffer** + +```cpp +thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_; +``` + +This is essentially: +```cpp +fp16_t data[32]; // Per-thread register array (VGPRs) +``` + +### Usage in Load/Store Operations + +**Load from DRAM:** +```cpp +a_block_tile = load_tile(a_copy_dram_window); +``` + +What happens internally: +1. Each thread queries the tile distribution: "Which elements do I own?" +2. Thread 0 learns it owns A[0,0:31] +3. Thread 0 loads those elements from DRAM into `a_block_tile.thread_buf_[0:31]` +4. All 256 threads do this **in parallel** + +**Store to LDS:** +```cpp +store_tile(a_copy_lds_window, a_block_tile); +``` + +What happens internally: +1. Each thread reads from its `a_block_tile.thread_buf_` +2. Thread 0 writes A[0,0:31] from its registers to LDS +3. All 256 threads do this **in parallel** +4. After `block_sync_lds()`, the entire tile is in shared LDS + +### Distributed Indexing + +The `static_distributed_tensor` supports compile-time indexing: + +```cpp +// Access using distributed indices (compile-time index list elided here) +auto value = a_block_tile(tile_distributed_index<...>{}); +``` + +Internally: +1. Convert distributed index → Y index (logical tile coordinates) +2. Calculate buffer offset using `ThreadTensorDesc` +3. Access `thread_buf_[offset]` + +All of this happens **at compile time** with zero runtime overhead! + +### Why This Design? + +**Benefits:** +1. **Parallel Memory Access**: All threads load/store simultaneously +2. **Vectorization**: Each thread can use vector loads (e.g., 8×fp16 at once) +3. **Zero Overhead**: All indexing resolved at compile time +4. **Type Safety**: Distribution mismatch caught at compile time +5. **Register Pressure**: Compiler knows exact VGPR usage + +**Trade-offs:** +- Requires compile-time tile sizes +- Distribution must be static +- More complex type system + +### Memory Hierarchy Summary + +``` +┌─────────────────────────────────────────────────────────────┐ +│ DRAM (Global Memory) │ +│ Full matrices A, B, C │ +└─────────────────────────────────────────────────────────────┘ + │ + │ load_tile (parallel, vectorized) + ↓ +┌─────────────────────────────────────────────────────────────┐ +│ VGPRs (Per-Thread Registers) │ +│ Thread 0: a_block_tile.thread_buf_ = [A[0,0:31]] │ +│ Thread 1: a_block_tile.thread_buf_ = [A[1,0:31]] │ +│ ...
│ +│ Thread 255: a_block_tile.thread_buf_ = [A[255,0:31]] │ +│ │ +│ ← static_distributed_tensor manages this distribution │ +└─────────────────────────────────────────────────────────────┘ + │ + │ store_tile (parallel, vectorized) + ↓ +┌─────────────────────────────────────────────────────────────┐ +│ LDS (Shared Memory) │ +│ Entire block tile (256×32) │ +│ Accessible to all threads in block │ +└─────────────────────────────────────────────────────────────┘ +``` + +**Key Insight:** +`static_distributed_tensor` is the abstraction that enables efficient, parallel data movement between DRAM and LDS through per-thread VGPRs, with all coordination happening at compile time. + + + diff --git a/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt b/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt new file mode 100644 index 0000000000..e16977921a --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(tile_tutorial_naive_gemm EXCLUDE_FROM_ALL practice_gemm.cpp) + +target_compile_options(tile_tutorial_naive_gemm PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 +) + +add_dependencies(tutorials tile_tutorial_naive_gemm) \ No newline at end of file diff --git a/tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md b/tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md new file mode 100644 index 0000000000..43cb01fb36 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md @@ -0,0 +1,618 @@ +# Host-Level Pipeline: Orchestrating Block-Level GEMM + +This document explains the **host-level pipeline** (`PracticeGemmHostPipeline`), which orchestrates the distribution of work across thread blocks and manages the high-level flow of the GEMM computation. + +## Overview + +The host-level pipeline is responsible for: +1. **Calculating tile coverage**: How many tiles are needed to cover matrices A, B, and C +2. **Block-to-tile mapping**: Assigning each thread block to a specific tile +3. **Creating tile windows**: Establishing sliding windows over tensor views +4. **Delegating computation**: Calling the block-level pipeline to perform actual GEMM +5. **Storing results**: Writing computed tiles from registers (VGPRs) back to DRAM + +```cpp +template <typename Problem, typename Policy> +struct PracticeGemmHostPipeline +{ + template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView> + CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, + const BDRAMTensorView& b_dram, + CDRAMTensorView& c_dram) const + { + // 1. Calculate problem dimensions and tile coverage + // 2. Map thread block to tile coordinates + // 3. Create tile windows over A and B + // 4. Call block-level pipeline to compute + // 5. Store result to C + } +}; +``` + +--- + +## Step 1: Calculate Problem Dimensions and Tile Coverage + +```cpp +// Size of the entire problem +const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // A is M x K +const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // C is M x N +const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // A is M x K + +// Size of the block tile +const auto MPerBlock = BlockTile::at(number<0>{}); // 256 +const auto NPerBlock = BlockTile::at(number<1>{}); // 128 +const auto KPerBlock = BlockTile::at(number<2>{}); // 32 + +// Number of block tiles needed to cover C matrix +const auto num_tile_n = integer_divide_ceil(N, NPerBlock); // ceil(256/128) = 2 +const auto num_tile_m = integer_divide_ceil(M, MPerBlock); // ceil(512/256) = 2 +``` + +### What's Happening: + +1.
**Extract problem dimensions** from tensor descriptors: + - `M = 512`: Rows in A and C + - `N = 256`: Columns in B and C + - `K = 64`: Inner dimension (columns of A, rows of B) + +2. **Get block tile sizes** from the `BlockTile` configuration: + - `MPerBlock = 256`: Each block processes 256 rows + - `NPerBlock = 128`: Each block processes 128 columns + - `KPerBlock = 32`: Each block processes 32 elements in K dimension per iteration + +3. **Calculate tile coverage**: + - `num_tile_m = ceil(M / MPerBlock) = ceil(512/256) = 2` tiles in M direction + - `num_tile_n = ceil(N / NPerBlock) = ceil(256/128) = 2` tiles in N direction + - **Total tiles = 2 × 2 = 4 tiles** → We need **4 thread blocks**! + +### Visual Representation: + +``` +Matrix C (512 × 256): +┌──────────────────────┬──────────────────────┐ +│ Tile (0,0) │ Tile (0,1) │ ← num_tile_n = 2 +│ 256×128 │ 256×128 │ +│ Block 0 │ Block 1 │ +│ │ │ +├──────────────────────┼──────────────────────┤ +│ Tile (1,0) │ Tile (1,1) │ +│ 256×128 │ 256×128 │ +│ Block 2 │ Block 3 │ +│ │ │ +└──────────────────────┴──────────────────────┘ + ↑ + num_tile_m = 2 + +Total blocks needed = 2 × 2 = 4 blocks + +Each block computes one 256×128 tile of the output matrix C. +``` + +### How Blocks Cover Matrices A and B: + +``` +Matrix A (512 × 64): Matrix B (256 × 64): +┌─────────────┬──────┐ ┌─────────────┬──────┐ +│ Block 0,2 │ K │ │ Block 0,1 │ K │ +│ uses rows │ → │ │ uses rows │ → │ +│ 0-255 │ │ │ 0-127 │ │ +├─────────────┼──────┤ ├─────────────┼──────┤ +│ Block 1,3 │ K │ │ Block 2,3 │ K │ +│ uses rows │ → │ │ uses rows │ → │ +│ 256-511 │ │ │ 128-255 │ │ +└─────────────┴──────┘ └─────────────┴──────┘ + 256 rows 64 cols 128 rows 64 cols + +Each block needs to iterate over K dimension (64/32 = 2 iterations) +``` + +--- + +## Step 2: Map Thread Block to Tile Coordinates + +```cpp +// Get block id (0 to total_blocks - 1) +const auto id_block = get_block_id(); + +// Map block id to 2D tile coordinates +const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n); +const auto tile_id = block2tile(id_block); + +const auto tile_id_m = tile_id.at(number<0>{}); // M coordinate +const auto tile_id_n = tile_id.at(number<1>{}); // N coordinate +``` + +### What's Happening: + +Each thread block needs to know **which tile of the output matrix C it should compute**. The `MakeBlock2TileMap` function creates a mapping from linear block ID to 2D tile coordinates. + +### The `MakeBlock2TileMap` Function: + +```cpp +CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) +{ + // Create a merge transform: (N0, M0) → linear index + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + // Convert linear block_id back to 2D coordinates + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + // Return (m_idx, n_idx) - note the swap! 
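+        // (The merge was built over (N0, M0), so calculate_lower_index fills
+        // unmerged = (n_idx, m_idx), with the M index varying fastest. The swap
+        // restores (m_idx, n_idx); the net effect is that consecutive block ids
+        // walk down the M dimension first, i.e. column-major order over tiles.)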
+ return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; +} +``` + +### In Our Example (2×2 Grid): + +```cpp +// Block 0: +id_block = 0 +tile_id = block2tile(0) = (0, 0) // Top-left tile +tile_id_m = 0, tile_id_n = 0 + +// Block 1: +id_block = 1 +tile_id = block2tile(1) = (1, 0) // Bottom-left tile +tile_id_m = 1, tile_id_n = 0 + +// Block 2: +id_block = 2 +tile_id = block2tile(2) = (0, 1) // Top-right tile +tile_id_m = 0, tile_id_n = 1 + +// Block 3: +id_block = 3 +tile_id = block2tile(3) = (1, 1) // Bottom-right tile +tile_id_m = 1, tile_id_n = 1 +``` + +**Key Point**: Each of the 4 blocks knows exactly which 256×128 tile of C it's responsible for computing! + +--- + +## Step 3: Calculate Tile Origin and Create Tile Windows + +```cpp +// Calculate the starting position of this tile in the global matrix +const auto tile_origin_m = tile_id_m * MPerBlock; // e.g., Block 1: 1 * 256 = 256 +const auto tile_origin_n = tile_id_n * NPerBlock; // e.g., Block 2: 1 * 128 = 128 + +// Create tile windows over A and B tensor views +const auto a_block_window = make_tile_window( + a_dram, // Tensor view over A + make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // Window size: 256×32 + {tile_origin_m, 0} // Origin: varies by block +); + +const auto b_block_window = make_tile_window( + b_dram, // Tensor view over B + make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), // Window size: 128×32 + {tile_origin_n, 0} // Origin: varies by block +); +``` + +### Tile Origins for Each Block: + +```cpp +// Block 0 (Tile 0,0): +tile_origin_m = 0 * 256 = 0 +tile_origin_n = 0 * 128 = 0 +a_block_window origin: (0, 0) → covers A rows 0-255 +b_block_window origin: (0, 0) → covers B rows 0-127 + +// Block 1 (Tile 1,0): +tile_origin_m = 1 * 256 = 256 +tile_origin_n = 0 * 128 = 0 +a_block_window origin: (256, 0) → covers A rows 256-511 +b_block_window origin: (0, 0) → covers B rows 0-127 + +// Block 2 (Tile 0,1): +tile_origin_m = 0 * 256 = 0 +tile_origin_n = 1 * 128 = 128 +a_block_window origin: (0, 0) → covers A rows 0-255 +b_block_window origin: (128, 0) → covers B rows 128-255 + +// Block 3 (Tile 1,1): +tile_origin_m = 1 * 256 = 256 +tile_origin_n = 1 * 128 = 128 +a_block_window origin: (256, 0) → covers A rows 256-511 +b_block_window origin: (128, 0) → covers B rows 128-255 +``` + +### What are Tile Windows? + +A **tile window** is a **sliding window** over a larger tensor view. It: +- Defines a **rectangular region** within the tensor +- Has a **fixed size** (e.g., 256×32 for A) +- Has an **origin** (starting position) +- Can be **moved** to access different regions + +### Visual Representation (Block 0 Example): + +``` +Matrix A (512 × 64): Matrix B (256 × 64): +┌─────────────┬─────────────┐ ┌─────────────┬─────────────┐ +│ ┏━━━━━━━━━┓ │ │ │ ┏━━━━━━━━━┓ │ │ +│ ┃ Window ┃ │ │ │ ┃ Window ┃ │ │ +│ ┃ 256×32 ┃ │ │ │ ┃ 128×32 ┃ │ │ +│ ┃ K=0-31 ┃ │ │ │ ┃ K=0-31 ┃ │ │ +│ ┗━━━━━━━━━┛ │ │ │ ┗━━━━━━━━━┛ │ │ +│ │ │ ├─────────────┼─────────────┤ +├─────────────┼─────────────┤ │ │ │ +│ │ │ │ │ │ +│ │ │ │ │ │ +│ │ │ │ │ │ +└─────────────┴─────────────┘ └─────────────┴─────────────┘ + Origin: (0, 0) Origin: (0, 0) + Covers rows 0-255 Covers rows 0-127 + Covers cols 0-31 (first K iteration) Covers cols 0-31 (first K iteration) +``` + +**Note**: The window initially covers K columns 0-31. It will move to cover K columns 32-63 in the next iteration.
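+
+To make the mapping and origins concrete, here is a small standalone sketch of our own (plain C++; `num_tile_m`, `MPerBlock`, and `NPerBlock` mirror the values above). It reproduces the block → tile → origin tables for the 2×2 grid:
+
+```cpp
+#include <cstdio>
+
+int main()
+{
+    const int num_tile_m = 2, MPerBlock = 256, NPerBlock = 128;
+    for(int block_id = 0; block_id < 4; ++block_id)
+    {
+        // Merge over (N0, M0): consecutive ids walk M first (column-major over tiles).
+        const int tile_id_m = block_id % num_tile_m;
+        const int tile_id_n = block_id / num_tile_m;
+        std::printf("block %d -> tile (%d,%d), A origin (%d,0), B origin (%d,0), C origin (%d,%d)\n",
+                    block_id, tile_id_m, tile_id_n,
+                    tile_id_m * MPerBlock, tile_id_n * NPerBlock,
+                    tile_id_m * MPerBlock, tile_id_n * NPerBlock);
+    }
+}
+```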
+ +### Tile Window Properties: + +```cpp +// Tile window structure (conceptual): +struct tile_window { + TensorView& tensor_view; // Reference to underlying tensor + Tuple window_lengths; // Size of the window (256, 32) + MultiIndex window_origin; // Starting position (0, 0) + + // Can move the window: + void move(MultiIndex step); // Shift window by step + + // Access data through the window: + auto load(); // Load data from windowed region +}; +``` + + +### Tile Window Movement: Iterating Over K Dimension + +In our example, **K=64** but **KPerBlock=32**, so we need **2 iterations** over the K dimension: + +``` +Matrix A (512 × 64) - Block 0's view: +┌─────────────┬─────────────┐ +│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │ +│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ ← Window slides along K +│ ┃ 256×32 ┃ │ ║ 256×32 ║ │ +│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │ +│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │ +├─────────────┼─────────────┤ +│ │ │ +│ Block 1's │ │ +│ region │ │ +└─────────────┴─────────────┘ + +Matrix B (256 × 64) - Block 0's view: +┌─────────────┬─────────────┐ +│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │ +│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ +│ ┃ 128×32 ┃ │ ║ 128×32 ║ │ +│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │ +│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │ +├─────────────┼─────────────┤ +│ Block 2's │ │ +│ region │ │ +└─────────────┴─────────────┘ +``` + +### How Windows Move (Conceptual - handled by block pipeline): + +```cpp +// Iteration 0: +a_block_window origin: (tile_origin_m, 0) // K columns 0-31 +b_block_window origin: (tile_origin_n, 0) // K columns 0-31 +// Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] + +// Move windows to next K position: +move_tile_window(a_block_window, {0, 32}); +move_tile_window(b_block_window, {0, 32}); + +// Iteration 1: +a_block_window origin: (tile_origin_m, 32) // K columns 32-63 +b_block_window origin: (tile_origin_n, 32) // K columns 32-63 +// Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] + +// Final result: +// C_tile = C_partial_0 + C_partial_1 +``` + +**Key Insight**: The tile windows **slide along the K dimension** to cover the full inner product. Each block accumulates partial results across K iterations to compute its final tile of C. + +--- + +## Step 4: Delegate to Block-Level Pipeline + +```cpp +// Get the block-level pipeline from policy +constexpr auto block_gemm_pipeline = + Policy::template GetPracticeGemmBlockPipeline(); + +// Calculate number of K iterations needed +int num_loops_k = integer_divide_ceil(K, KPerBlock); // ceil(64/32) = 2 + +// Allocate shared memory (LDS) for block-level computation +__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; + +// Call block-level pipeline to compute C tile +const auto c_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); +``` + +### What's Happening: + +1. **Retrieve block pipeline**: The policy provides the block-level GEMM implementation +2. **Calculate K iterations**: How many times to iterate over the K dimension + - In our example: `K=64, KPerBlock=32` → **2 iterations** + - Each iteration processes 32 elements of the K dimension + - Results are accumulated across iterations + +3. **Allocate shared memory**: + - `__shared__` declares memory shared by all threads in the block + - `GetStaticLDSSize()` returns the required size in bytes + - This memory is used for: + - Staging data from DRAM → LDS + - Cooperative loading by threads + - Fast access during computation + +4. 
**Execute block pipeline**: + - Takes A and B tile windows as input + - Performs the GEMM computation: `C_tile = A_tile × B_tile` + - Returns result in `c_block_tile` (stored in VGPRs - registers) + +### Memory Hierarchy During Computation: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ DRAM (Global Memory) - Slowest, Largest │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ A matrix │ │ B matrix │ │ C matrix │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ │ +└─────────────────────────────────────────────────────────────┘ + ↓ load ↓ load ↑ store +┌─────────────────────────────────────────────────────────────┐ +│ LDS (Shared Memory) - Fast, Limited Size (~64KB) │ +│ ┌─────────────┐ ┌─────────────┐ │ +│ │ A_tile │ │ B_tile │ ← Staged here │ +│ │ (p_smem) │ │ (p_smem) │ │ +│ └─────────────┘ └─────────────┘ │ +└─────────────────────────────────────────────────────────────┘ + ↓ load ↓ load +┌─────────────────────────────────────────────────────────────┐ +│ VGPRs (Registers) - Fastest, Smallest (~256 regs/thread) │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ c_block_tile (accumulated result) │ │ +│ │ Computation happens here using MFMA instructions │ │ +│ └─────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +### Block Pipeline Responsibilities: + +The block pipeline (called here) will: +1. Load A and B tiles from DRAM → LDS (cooperative loading) +2. Distribute work among warps +3. Each warp loads its portion from LDS → VGPRs +4. Perform MFMA operations: `C += A × B` +5. Accumulate results in VGPRs +6. Return final `c_block_tile` in registers + +--- + +## Step 5: Store Results to DRAM + +```cpp +// Create a tile window over C for writing results +auto c_window = make_tile_window( + c_dram, // Tensor view over C + make_tuple(number<MPerBlock>{}, number<NPerBlock>{}), // Window size: 256×128 + {tile_origin_m, tile_origin_n} // Origin: varies by block +); + +// Store computed tile from VGPRs to DRAM +store_tile(c_window, c_block_tile); +``` + +### C Window Origins for Each Block: + +```cpp +// Block 0: Writes to top-left tile +c_window origin: (0, 0) → writes to C[0:255, 0:127] + +// Block 1: Writes to bottom-left tile +c_window origin: (256, 0) → writes to C[256:511, 0:127] + +// Block 2: Writes to top-right tile +c_window origin: (0, 128) → writes to C[0:255, 128:255] + +// Block 3: Writes to bottom-right tile +c_window origin: (256, 128) → writes to C[256:511, 128:255] +``` + +### What's Happening: + +1. **Create C tile window**: + - Size: 256×128 (matches our block tile size) + - Origin: Varies by block - each block writes to its assigned region + - This window defines **where** to write the results + +2. **Store tile to DRAM**: + - `c_block_tile`: Computed results in VGPRs (registers) + - `c_window`: Destination window in DRAM + - `store_tile()`: Efficiently writes data from registers → DRAM + +### The `store_tile` Function: + +Recall from our earlier discussion that `store_tile` does the following: + +```cpp +template <typename TileWindow, typename DistributedTensor> +void store_tile(TileWindow& tile_window_tmp, + const DistributedTensor& dstr_tensor) +{ + // 1. Extract tile distribution from distributed tensor + using TileDstr = typename DistributedTensor::TileDistribution; + + // 2.
Upgrade simple tile window to one with distribution + auto tile_window = make_tile_window( + tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + TileDstr{} // Add distribution info + ); + + // 3. Store using vectorized writes + tile_window.store(dstr_tensor); +} +``` + +### Memory Flow: + +``` +VGPRs (Registers) DRAM (Global Memory) +┌─────────────────────┐ ┌─────────────────────┐ +│ c_block_tile │ │ C matrix │ +│ ┌───┬───┬───┬───┐ │ │ ┌───────────────┐ │ +│ │W0 │W1 │W2 │W3 │ │ store_tile │ │ │ │ +│ ├───┼───┼───┼───┤ │ ==========> │ │ c_window │ │ +│ │...│...│...│...│ │ vectorized │ │ (256×128) │ │ +│ └───┴───┴───┴───┘ │ │ │ │ │ +│ Distributed across │ │ └───────────────┘ │ +│ threads/warps │ │ Origin: (0, 0) │ +└─────────────────────┘ └─────────────────────┘ + +Each thread writes its portion using vector stores (e.g., float4) +``` + +### Store Optimization: + +The `store_tile` function: +- Uses **vectorized stores** (write multiple elements at once) +- Ensures **coalesced memory access** (adjacent threads write adjacent memory) +- Respects **tile distribution** (each thread knows what data it owns) +- Handles **out-of-bounds** checking (for partial tiles at boundaries) + +--- + +## Complete Flow Visualization + +Let's trace the complete flow for **Block 0** (other blocks follow the same pattern): + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Step 1: Calculate Tile Coverage │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ M=512, N=256, K=64 │ │ +│ │ MPerBlock=256, NPerBlock=128, KPerBlock=32 │ │ +│ │ num_tile_m = ceil(512/256) = 2 │ │ +│ │ num_tile_n = ceil(256/128) = 2 │ │ +│ │ Total blocks needed = 2 × 2 = 4 blocks │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 2: Map Block to Tile (Block 0 example) │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ Block ID: 0 │ │ +│ │ Tile coordinates: (0, 0) - top-left tile │ │ +│ │ Tile origin: (0, 0) │ │ +│ │ │ │ +│ │ (Blocks 1,2,3 get different tile coordinates) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 3: Create Tile Windows │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ a_block_window: 256×32 starting at (0,0) over A │ │ +│ │ b_block_window: 128×32 starting at (0,0) over B │ │ +│ │ Windows initially cover K columns 0-31 │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 4: Execute Block Pipeline (2 K iterations) │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ Allocate shared memory (LDS) │ │ +│ │ Call block_gemm_pipeline(a_window, b_window, 2, p_smem) │ │ +│ │ │ │ +│ │ K Iteration 0 (K=0-31): │ │ +│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] │ │ +│ │ └─ Move windows: {0, 32} │ │ +│ │ │ │ +│ │ K Iteration 1 (K=32-63): │ │ +│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Compute: 
C_partial_1 = A[:, 32:63] × B[:, 32:63] │ │ +│ │ └─ Accumulate: C_tile = C_partial_0 + C_partial_1 │ │ +│ │ │ │ +│ │ Return c_block_tile in VGPRs (256×128 accumulated result) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 5: Store Results │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ Create c_window: 256×128 starting at (0,0) over C │ │ +│ │ store_tile(c_window, c_block_tile) │ │ +│ │ └─ Write from VGPRs → DRAM (vectorized stores) │ │ +│ │ │ │ +│ │ Block 0 writes to C[0:255, 0:127] │ │ +│ │ (Other blocks write to their respective regions) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + +All 4 blocks execute in parallel, each computing its assigned 256×128 tile! +``` + +--- + +## Key Concepts Summary + +### 1. **Tile Coverage** +- Determines how many thread blocks are needed +- Each block processes one tile of the output matrix C +- Calculated as `ceil(dimension / tile_size)` + +### 2. **Block-to-Tile Mapping** +- Maps linear block ID to 2D tile coordinates +- Uses column-major ordering for better memory coalescing +- Each block knows which tile it's responsible for + +### 3. **Tile Windows** +- **Sliding windows** over larger tensor views +- Define a rectangular region with fixed size and movable origin +- Provide efficient, structured access to tensor data +- Can be moved to access different regions (e.g., for K iterations) + +### 4. **Memory Hierarchy** +- **DRAM (Global)**: Largest, slowest - stores full matrices +- **LDS (Shared)**: Medium, fast - stages tiles for cooperative access +- **VGPRs (Registers)**: Smallest, fastest - performs computation + +### 5. **Data Flow** +``` +DRAM → Tile Windows → LDS → VGPRs → Computation → VGPRs → DRAM + ↑ ↓ + A, B matrices C matrix +``` + +--- + +## Next Steps + +The host-level pipeline has set up the work and delegated to the block-level pipeline. Next, we'll explore: +- **Block-level pipeline**: How tiles are loaded, distributed to warps, and computed +- **Warp-level pipeline**: How warps perform MFMA operations +- **Memory optimization**: LDS usage, bank conflicts, coalescing + +The host level provides the **orchestration**, while the block and warp levels provide the **execution**! + diff --git a/tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md b/tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md new file mode 100644 index 0000000000..7cd0d06fc5 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md @@ -0,0 +1,464 @@ +# PracticeGemmKernel: Understanding the Kernel Entry Point + +This document explains the `PracticeGemmKernel` structure, which serves as the **entry point** for our GEMM GPU kernel. We'll dive deep into how raw memory is transformed into structured tensor views. + +## Overview + +The `PracticeGemmKernel` is a templated struct that: +1. Takes raw device memory pointers for matrices A, B, and C +2. Wraps them into **tensor views** - logical, structured views over physical memory +3. 
Dispatches to the host-level pipeline for computation + +```cpp +template <typename Problem_, typename Policy_> +struct PracticeGemmKernel +{ + using Problem = remove_cvref_t<Problem_>; + using Policy = remove_cvref_t<Policy_>; + + static constexpr index_t kBlockSize = 256; + + CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a, + const typename Problem::BDataType* p_b, + typename Problem::CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c) const + { + // Step 1: Create tensor views over raw memory + auto a_dram = make_naive_tensor_view<address_space_enum::global>( + p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{}); + + auto b_dram = make_naive_tensor_view<address_space_enum::global>( + p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{}); + + const auto c_dram = make_naive_tensor_view<address_space_enum::global>( + p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{}); + + // Step 2: Dispatch to host-level pipeline + PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram); + } +}; +``` + +--- + +## What are Tensor Views? + +A **tensor view** is a **logical, structured view over raw physical memory**. It doesn't own or allocate memory; it simply provides a way to interpret and access existing memory as a multi-dimensional tensor. + +### Key Components of a Tensor View: + +1. **Memory Type**: Where the data lives (global/DRAM, LDS/shared, registers) +2. **Raw Pointer**: Points to the actual data in memory +3. **Shape**: Dimensions of the tensor (e.g., M×K for matrix A) +4. **Strides**: How to navigate through memory to access elements +5. **Guaranteed Vector Length**: How many consecutive elements can be loaded in one vector instruction +6. **Guaranteed Vector Stride**: The stride of those vectorizable elements + +--- + +## The Memory Abstraction Hierarchy + +CK Tile uses a three-layer abstraction to go from raw memory to structured tensors: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Layer 3: TENSOR VIEW │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ • Logical multi-dimensional structure │ │ +│ │ • Shape: (M, K) = (256, 32) │ │ +│ │ • Strides: (32, 1) for row-major layout │ │ +│ │ • Provides: operator[], coordinate-based access │ │ +│ │ • Knows: How to map (i,j) → linear offset │ │ +│ └─────────────────────────────────────────────────────────┘ │ +│ ↓ wraps │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ Layer 2: BUFFER VIEW │ │ +│ │ ┌─────────────────────────────────────────────────────┐ │ │ +│ │ │ • Linear view of memory │ │ │ +│ │ │ • Pointer: p_data_ → device memory │ │ │ +│ │ │ • Size: Total number of elements │ │ │ +│ │ │ • Address space: global/LDS/generic │ │ │ +│ │ │ • Provides: Vectorized loads/stores, bounds checking│ │ │ +│ │ └─────────────────────────────────────────────────────┘ │ │ +│ └─────────────────────────────────────────────────────────┘ │ +│ ↓ wraps │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ Layer 1: RAW PHYSICAL MEMORY │ │ +│ │ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ │ +│ │ │ 0.0 │ 1.0 │ 2.0 │ 3.0 │ 4.0 │ 5.0 │ 6.0 │ 7.0 │ ...
│ │ │ +│ │ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ │ │ +│ │ ↑ │ │ +│ │ p_a (raw pointer from hipMalloc) │ │ +│ └─────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +--- + +## Deep Dive: `make_naive_tensor_view` + +Let's break down the function call for matrix A: + +```cpp +auto a_dram = make_naive_tensor_view<address_space_enum::global>( + p_a, // Raw pointer to device memory + make_tuple(M, K), // Shape: (256, 32) + make_tuple(stride_a, 1), // Strides: (32, 1) - row-major + number<8>{}, // Guaranteed vector length + number<1>{} // Guaranteed vector stride +); +``` + +### Function Signature: + +```cpp +template <address_space_enum BufferAddressSpace, + typename DataType, + typename... Lengths, + typename... Strides, + index_t GuaranteedLastDimensionVectorLength = -1, + index_t GuaranteedLastDimensionVectorStride = -1> +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_view(DataType* __restrict__ p, + const tuple<Lengths...>& lengths, + const tuple<Strides...>& strides, + number<GuaranteedLastDimensionVectorLength> = number<-1>{}, + number<GuaranteedLastDimensionVectorStride> = number<-1>{}) +{ + // Step 1: Create tensor descriptor (shape + stride information) + auto desc = make_naive_tensor_descriptor(lengths, + strides, + number<GuaranteedLastDimensionVectorLength>{}, + number<GuaranteedLastDimensionVectorStride>{}); + + // Step 2: Create buffer view (pointer + size + address space) + auto buffer_view = + make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size()); + + // Step 3: Combine into tensor view + return tensor_view{buffer_view, desc}; +} +``` + +--- + +## Parameter Breakdown + +### 1. **Template Parameter: `address_space_enum::global`** + +Specifies where the memory lives: +- `global`: GPU global memory (DRAM) - slowest but largest +- `lds`: Local Data Share (shared memory) - fast, limited size +- `generic`: Generic address space +- `vgpr`: Vector General Purpose Registers - fastest, smallest + +In our case, `global` means the data is in GPU DRAM. + +### 2. **`p_a` - Raw Pointer** + +The raw device memory pointer returned by `hipMalloc`. Points to the start of the matrix data. + +### 3. **`make_tuple(M, K)` - Shape/Lengths** + +Defines the logical dimensions of the tensor: +- For matrix A: `(256, 32)` means 256 rows, 32 columns +- This is the **logical view**, independent of how data is physically laid out + +### 4. **`make_tuple(stride_a, 1)` - Strides** + +Defines how to navigate through memory: +- **Stride for dimension 0 (rows)**: `stride_a = K = 32` + - To move to the next row, skip 32 elements +- **Stride for dimension 1 (columns)**: `1` + - To move to the next column, skip 1 element + +**Row-major layout example:** +``` +Memory: [a₀₀, a₀₁, a₀₂, ..., a₀₃₁, a₁₀, a₁₁, a₁₂, ..., a₁₃₁, ...] + ↑ ↑ + Row 0 starts here Row 1 starts here (offset = 32) + +To access element A[i][j]: + offset = i * stride_a + j * 1 + = i * 32 + j +``` + +### 5. **`number<8>{}` - Guaranteed Last Dimension Vector Length** + +This tells the tensor view: **"The last dimension (K) is guaranteed to have at least 8 consecutive elements that can be loaded together in a single vector instruction."** + +#### Why is this important? + +Modern GPUs can load multiple elements in one instruction (vectorized loads): +- `float4`: Load 4 floats at once +- `float8`: Load 8 floats at once (if supported) + +By specifying `number<8>{}`, we're telling the system: +- "You can safely use vector loads of up to 8 elements" +- "The memory alignment and layout support this" + +**Example:** +```cpp +// Without vectorization (slow): +for (int j = 0; j < 8; j++) { + data[j] = memory[offset + j]; // 8 separate loads +} + +// With vectorization (fast): +float8 vec = *reinterpret_cast<float8*>(&memory[offset]); // 1 load! +``` + +### 6.
**`number<1>{}` - Guaranteed Last Dimension Vector Stride** + +This specifies the **stride between consecutive vectorizable elements** in the last dimension. + +- `number<1>{}` means: "Consecutive elements in the last dimension are contiguous in memory (stride = 1)" +- This confirms that elements `A[i][0], A[i][1], A[i][2], ..., A[i][7]` are stored consecutively + +**Why does this matter?** + +For efficient vectorized loads, elements must be: +1. **Contiguous** (stride = 1) ✓ +2. **Aligned** properly in memory +3. **Within the same cache line** (ideally) + +If the stride were `2`, it would mean: +``` +A[i][0] is at offset 0 +A[i][1] is at offset 2 (not 1!) +A[i][2] is at offset 4 +``` +This would prevent efficient vectorization. + +--- + +## What is a Buffer View? + +A **buffer view** is the middle layer between raw memory and tensor view. It provides: + +### Core Responsibilities: + +1. **Memory Management** + - Holds the raw pointer: `T* p_data_` + - Tracks buffer size: `BufferSizeType buffer_size_` + - Knows the address space: `global`, `lds`, etc. + +2. **Vectorized Access** + ```cpp + template + CK_TILE_DEVICE VectorType get(index_t offset); + ``` + - Provides efficient vector loads/stores + - Handles alignment requirements + +3. **Bounds Checking** (optional) + ```cpp + template + CK_TILE_DEVICE auto get(index_t i, index_t linear_offset); + ``` + - Can optionally check if access is within bounds + - Returns invalid value (default 0) for out-of-bounds access + +4. **Address Space Awareness** + - Uses different load/store instructions based on address space + - Global memory: `global_load`, `global_store` + - LDS: `ds_read`, `ds_write` + +### Buffer View Structure: + +```cpp +template +struct buffer_view +{ + T* p_data_; // Raw pointer + BufferSizeType buffer_size_; // Total elements + remove_cvref_t invalid_element_value_; // Value for OOB access + + // Access operators + const T& operator[](index_t i) const; // Read + T& operator()(index_t i); // Write + + // Vectorized access + template + VectorType get(index_t offset); +}; +``` + +--- + +## Visual Example: Matrix A Memory Layout + +Let's visualize how matrix A (256×32, fp16) is organized: + +### Raw Physical Memory (Linear): +``` +GPU DRAM Address Space: +┌─────────────────────────────────────────────────────────────────┐ +│ Byte 0 │ +│ ↓ │ +│ [a₀₀][a₀₁][a₀₂]...[a₀₃₁][a₁₀][a₁₁][a₁₂]...[a₁₃₁][a₂₀]... │ +│ ↑ ↑ │ +│ Row 0 (32 elements) Row 1 (32 elements) │ +│ │ +│ Total: 256 rows × 32 cols × 2 bytes/element = 16,384 bytes │ +└─────────────────────────────────────────────────────────────────┘ + ↑ + p_a (raw pointer) +``` + +### Buffer View Layer: +``` +buffer_view +┌─────────────────────────────────────────────────────────────────┐ +│ p_data_ = p_a │ +│ buffer_size_ = 256 × 32 = 8,192 elements │ +│ address_space = global (DRAM) │ +│ │ +│ Provides: │ +│ • Linear indexing: buffer_view[i] → element at offset i │ +│ • Vectorized loads: get(offset) → load 4 fp16s at once│ +│ • Bounds checking: is offset < buffer_size_? │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### Tensor View Layer: +``` +tensor_view +┌─────────────────────────────────────────────────────────────────┐ +│ Shape: (256, 32) │ +│ Strides: (32, 1) │ +│ Guaranteed vector length: 8 │ +│ Guaranteed vector stride: 1 │ +│ │ +│ Logical 2D View: │ +│ Col: 0 1 2 ... 31 │ +│ Row 0: [a₀₀][a₀₁][a₀₂] ... [a₀₃₁] ← Can vector load 8 at once│ +│ Row 1: [a₁₀][a₁₁][a₁₂] ... [a₁₃₁] │ +│ Row 2: [a₂₀][a₂₁][a₂₂] ... [a₂₃₁] │ +│ ... │ +│ Row 255: [a₂₅₅,₀] ... 
[a₂₅₅,₃₁] │
+│                                                                 │
+│ Provides:                                                       │
+│ • Multi-dimensional indexing: A[i][j]                           │
+│ • Coordinate transformation: (i,j) → linear offset = i*32 + j   │
+│ • Tile window creation: Extract sub-tensors                     │
+└─────────────────────────────────────────────────────────────────┘
+```
+
+---
+
+## Complete Flow: Raw Memory → Tensor View
+
+Let's trace the complete transformation for matrix A:
+
+### Step 1: Kernel Launch (Host Side)
+```cpp
+// On host: Allocate device memory
+hipMalloc(&p_a, M * K * sizeof(fp16_t)); // Returns raw pointer
+
+// Launch kernel
+kernel<<<kGridSize, kBlockSize>>>(p_a, p_b, p_c, M, N, K, ...);
+```
+
+### Step 2: Inside Kernel (Device Side)
+```cpp
+// Receive raw pointer
+const fp16_t* p_a; // Points to GPU DRAM
+
+// Step 2a: Create tensor descriptor
+auto desc = make_naive_tensor_descriptor(
+    make_tuple(256, 32), // Shape
+    make_tuple(32, 1),   // Strides
+    number<8>{},         // Vector length
+    number<1>{}          // Vector stride
+);
+// desc now knows: "This is a 256×32 tensor, row-major, vectorizable by 8"
+
+// Step 2b: Create buffer view
+auto buffer_view = make_buffer_view<address_space_enum::global>(
+    p_a,     // Raw pointer
+    256 * 32 // Total elements
+);
+// buffer_view now wraps p_a with size and address space info
+
+// Step 2c: Create tensor view
+auto a_dram = tensor_view{buffer_view, desc};
+// a_dram now provides structured, multi-dimensional access to p_a
+```
+
+### Step 3: Using the Tensor View
+```cpp
+// Access element A[i][j]
+auto value = a_dram[make_tuple(i, j)];
+
+// Create a tile window (sub-tensor)
+auto tile = make_tile_window(
+    a_dram,
+    make_tuple(16, 16), // 16×16 tile
+    make_tuple(0, 0)    // Starting at origin
+);
+
+// Load tile into registers with vectorization
+auto tile_data = load_tile(tile); // Uses vector loads internally!
+```
+
+---
+
+## Why This Abstraction?
+
+### Benefits:
+
+1. **Type Safety**: Can't accidentally access wrong dimensions
+2. **Performance**: Compiler knows about vectorization opportunities
+3. **Flexibility**: Same code works for different memory spaces (DRAM, LDS, registers)
+4. **Maintainability**: Logical structure separate from physical layout
+5. **Optimization**: Guaranteed vector properties enable aggressive optimizations
+
+### Example: Without Tensor Views (Manual Indexing)
+```cpp
+// Ugly, error-prone, hard to optimize:
+for (int i = 0; i < 16; i++) {
+    for (int j = 0; j < 16; j++) {
+        float val = p_a[tile_offset_i * stride_a + tile_offset_j + i * stride_a + j];
+        // Hope the compiler vectorizes this? 🤞
+    }
+}
+```
+
+### Example: With Tensor Views (Clean, Optimized)
+```cpp
+// Clean, safe, automatically vectorized:
+auto tile = make_tile_window(a_dram, make_tuple(16, 16), origin);
+auto tile_data = load_tile(tile); // Vectorized loads guaranteed!
+```
+
+---
+
+## Summary
+
+The `PracticeGemmKernel` entry point transforms raw GPU memory into structured, multi-dimensional tensors through a three-layer abstraction:
+
+1. **Raw Memory**: Linear array of bytes in GPU DRAM
+2. **Buffer View**: Adds size, address space, and vectorized access
+3. **Tensor View**: Adds shape, strides, and multi-dimensional indexing
+
+This abstraction enables:
+- ✅ Clean, readable code
+- ✅ Type-safe multi-dimensional access
+- ✅ Automatic vectorization
+- ✅ Flexible memory space handling
+- ✅ Efficient tile-based computation
+
+The tensor views created here are then passed to the host-level pipeline, which orchestrates the block-level GEMM computation!
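+
+---
+
+## Appendix: The Offset Mapping in Plain C++
+
+To see the coordinate-to-offset mapping without any framework machinery, here is a minimal, self-contained sketch. `SimpleView2D` is a hypothetical teaching aid written for this appendix, not a CK Tile type:
+
+```cpp
+#include <cstdio>
+#include <vector>
+
+// A toy 2-D "view": shape + strides over memory it does not own.
+struct SimpleView2D
+{
+    const float* data;
+    int lengths[2]; // shape, e.g. (M, K)
+    int strides[2]; // e.g. (K, 1) for row-major
+
+    float at(int i, int j) const
+    {
+        // Coordinate -> linear offset: i * stride0 + j * stride1
+        return data[i * strides[0] + j * strides[1]];
+    }
+};
+
+int main()
+{
+    const int M = 4, K = 8;
+    std::vector<float> storage(M * K);
+    for(int e = 0; e < M * K; ++e)
+        storage[e] = static_cast<float>(e); // element value == its linear offset
+
+    SimpleView2D a{storage.data(), {M, K}, {K, 1}};
+
+    // A[2][3] lives at offset 2*8 + 3 = 19
+    std::printf("A[2][3] = %.1f (expected 19.0)\n", a.at(2, 3));
+}
+```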
+ diff --git a/tutorial/ck_tile/01_naive_gemm/README.md b/tutorial/ck_tile/01_naive_gemm/README.md new file mode 100644 index 0000000000..f2caf7d993 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/README.md @@ -0,0 +1,150 @@ +# CK Tile Practice GEMM Example + +This is a practice implementation of a GEMM (General Matrix Multiplication) kernel using the CK Tile API. It demonstrates the fundamental concepts of GPU kernel development using CK Tile's hierarchical tile system. + +## CK Tile API Structure + +In the composable_kernel library's ck_tile API, **A Kernel is composed of a Problem, a Policy and an Epilogue**: + +1. **Problem** describes the shape, data type, data layout, precision of our GEMM matrices +2. **Policy** describes how the data in the matrix (or tile) is mapped to the threads +3. **Epilogue** describes additional computation work performed after the gemm computations (this example does not have an epilogue) + +## Overview + +This example implements a complete GEMM kernel `C = A × B` using the CK Tile framework, showcasing: + +- **Problem Setup** - Setting up the problem (input/output shapes, data types, mathematical operations), composing a kernel (pipeline, policy, epilogue), kernel launch +- **Block-level Pipelining** - creating tensor views, dispatching to block-level GEMM +- **Block-level GEMM Computation** - Block tiles, tile window creation, loading/storing to DRAM and Register memory +- **Warp-level GEMM Computation** - Warp tiles, MFMA level computation + +## Problem Setup and Data Flow + +### Problem Size Configuration +We set the problem size using the M, N and K variables: +```cpp +ck_tile::index_t M = 1024; // Number of rows in A and C +ck_tile::index_t N = 512; // Number of columns in B and C +ck_tile::index_t K = 256; // Number of columns in A, rows in B +``` + +### Host Matrix Creation +Three host matrices A (M×K), B (N×K) and C (M×N) are created, initialized on the CPU and copied over to the GPU global/DRAM memory: +```cpp +// Host tensors with proper strides +ck_tile::HostTensor a_host(a_lengths, a_strides); // M × K +ck_tile::HostTensor b_host(b_lengths, b_strides); // N × K +ck_tile::HostTensor c_host(c_lengths, c_strides); // M × N + +// Initialize with random data +ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); +ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_host); + +// Allocate device memory and transfer data +ck_tile::DeviceMem a_device(a_host); +a_device.ToDevice(a_host.data()); +``` + +### PracticeGemmShape Configuration +A PracticeGemmShape struct holds the dimension of each BlockTile and WaveTile: + +```cpp +using BlockTile = ck_tile::sequence<256, 128, 32>; // M, N, K per block +using WaveTile = ck_tile::sequence<16, 16, 16>; // M, N, K per wave +``` +- A BlockTile of size MxK (256x32) on A matrix and NxK (128x32) on B matrix. A WaveTile of size MxN (16x16) on C matrix. + + +- BlockTiles iterate in K dimension to fetch data required for computing region of C covered by C's block tile. +- BlockTiles are further subdivided into WarpTiles. +- WarpTiles over A and B similarly work together to calculate the WarpTile of C. 
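+
+To sanity-check the tile arithmetic above, the counts can be reproduced with a short standalone program. This is plain C++ mirroring the README's numbers; nothing in it is CK Tile API:
+
+```cpp
+#include <cstdio>
+
+int main()
+{
+    // BlockTile: M, N, K per thread block (as configured above)
+    const int BlockTile_M = 256, BlockTile_N = 128, BlockTile_K = 32;
+    // WaveTile: M, N, K per warp iteration
+    const int WaveTile_M = 16, WaveTile_N = 16, WaveTile_K = 16;
+
+    // Wave tiles needed to cover one block tile of C (M x N)
+    const int waves_m = BlockTile_M / WaveTile_M; // 16
+    const int waves_n = BlockTile_N / WaveTile_N; // 8
+
+    // Wave-tile steps along K inside one block tile
+    const int waves_k = BlockTile_K / WaveTile_K; // 2
+
+    std::printf("wave tiles per C block tile: %d x %d = %d\n",
+                waves_m, waves_n, waves_m * waves_n);
+    std::printf("wave-tile steps along K per block tile: %d\n", waves_k);
+}
+```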
+
+### Problem and Policy Composition
+```cpp
+// A Problem is composed from Shape and info about the data
+using PracticeGemmHostProblem = ck_tile::PracticeGemmHostProblem<ADataType,
+                                                                 BDataType,
+                                                                 CDataType,
+                                                                 AccDataType,
+                                                                 PracticeGemmShape>;
+
+// A Policy is created describing data-to-thread mapping
+using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
+
+// A Kernel is then composed of Problem and Policy
+using gemm_kernel = ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
+```
+
+### Kernel Launch
+`ck_tile::launch_kernel()` is used to launch the kernel on the device. It calls the `operator()` function of `PracticeGemmKernel{}`:
+```cpp
+float ave_time = ck_tile::launch_kernel(
+    ck_tile::stream_config{nullptr, true, 0, 0, 1},
+    ck_tile::make_kernel(
+        gemm_kernel{}, // Kernel composed of Problem + Policy
+        kGridSize,     // Grid dimensions
+        kBlockSize,    // Block dimensions
+        0,             // Dynamic shared memory
+        // Kernel arguments: device buffers and problem dimensions
+        a_device.GetDeviceBuffer(), b_device.GetDeviceBuffer(), c_device.GetDeviceBuffer(),
+        M, N, K, stride_a, stride_b, stride_c));
+```
+
+### Result Verification
+The results from the kernel are compared against a CPU-based reference implementation:
+```cpp
+// CPU reference implementation
+ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
+reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(a_host, b_host, c_host_ref);
+
+// Device results
+ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
+
+// Verify correctness
+bool pass = ck_tile::check_err(c_host_dev, c_host_ref);
+```
+
+### Runtime Flow
+
+The main program (`practice_gemm.cpp`) is the entry point for the runtime flow:
+
+```cpp
+int main()
+{
+    // 1. Define data types and problem sizes
+    using ADataType = ck_tile::half_t;
+    ck_tile::index_t M = 512, N = 256, K = 64;
+
+    // 2. Create host tensors and initialize
+    ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
+    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
+
+    // 3. Allocate device memory and transfer data
+    ck_tile::DeviceMem a_device(a_host);
+
+    // 4. Configure tile shapes
+    using BlockTile = ck_tile::sequence<256, 128, 32>;
+    using WaveTile  = ck_tile::sequence<16, 16, 16>;
+
+    // 5. Launch kernel
+    using gemm_kernel =
+        ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
+    float ave_time = ck_tile::launch_kernel(/*...*/);
+
+    // 6. Verify results
+    bool pass = verify_results(a_host, b_host, c_host);
+
+    // 7. Print performance metrics
+    print_performance_metrics(ave_time, M, N, K);
+}
+```
+
+## Building and Running
+
+```bash
+# From composable_kernel root directory
+mkdir build && cd build
+sh ../script/cmake-ck-dev.sh ../
+make tile_example_practice_gemm -j
+
+# Run with sample sizes
+./bin/tile_example_practice_gemm
+```
+This example serves as a foundation for understanding more complex GEMM implementations and optimization strategies in the CK Tile framework.
diff --git a/tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md b/tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md
new file mode 100644
index 0000000000..d0b8400b9c
--- /dev/null
+++ b/tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md
@@ -0,0 +1,506 @@
+# Practice GEMM: Step-by-Step Code Walkthrough
+
+This document provides a detailed walkthrough of `practice_gemm.cpp`, explaining each step of implementing a GEMM (General Matrix Multiplication) kernel using the CK Tile API.
+
+## Overview
+
+We'll implement `C = A × B` where:
+- `A` is an `M × K` matrix
+- `B` is an `N × K` matrix (note: transposed layout)
+- `C` is an `M × N` matrix
+
+The implementation uses a hierarchical tiling strategy with two levels:
+1. **Block Tiles**: Processed by thread blocks
+2. 
**Wave Tiles**: Processed by warps (wavefronts) within blocks
+
+---
+
+## Step 1: Define Data Types
+
+```cpp
+using ADataType   = ck_tile::half_t;
+using BDataType   = ck_tile::half_t;
+using CDataType   = float;
+using AccDataType = float;
+```
+
+**What's happening:**
+- We use `half_t` (FP16) for input matrices A and B.
+- We use `float` (FP32) for output matrix C and for accumulation, for numerical accuracy.
+- In typical CK examples, this information is part of a `GemmConfig` struct, but here we define it directly for simplicity.
+
+---
+
+## Step 2: Define Problem Size
+
+```cpp
+ck_tile::index_t M = 512;
+ck_tile::index_t N = 256;
+ck_tile::index_t K = 64;
+ck_tile::index_t verification = 1;
+
+ck_tile::index_t stride_a = K;
+ck_tile::index_t stride_b = K;
+ck_tile::index_t stride_c = N;
+```
+
+**What's happening:**
+- `M = 512`: Number of rows in A and C
+- `N = 256`: Number of columns in B and C
+- `K = 64`: Inner dimension (columns of A, rows of B)
+- Strides define memory layout (row-major for A and C, transposed for B)
+
+**Memory Layout:**
+```
+Matrix A (M×K):    Matrix B (N×K):    Matrix C (M×N):
+[512 rows]         [256 rows]         [512 rows]
+[64 cols]          [64 cols]          [256 cols]
+stride = K         stride = K         stride = N
+```
+
+---
+
+## Step 3: Create Host Tensors
+
+```cpp
+auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
+auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
+auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
+
+auto a_strides = std::array<ck_tile::index_t, 2>{stride_a, 1};
+auto b_strides = std::array<ck_tile::index_t, 2>{stride_b, 1};
+auto c_strides = std::array<ck_tile::index_t, 2>{stride_c, 1};
+
+ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
+ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
+ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides);
+```
+
+**What's happening:**
+- We create three tensors in host (CPU) memory
+- Each tensor is defined by its shape (`lengths`) and memory layout (`strides`)
+- `HostTensor` is a CK Tile utility class that manages CPU memory
+
+**Stride explanation:**
+- For A: `stride_a = K` means moving to the next row requires skipping K elements
+- For B: `stride_b = K` means B is stored in transposed format
+- For C: `stride_c = N` means row-major layout
+
+---
+
+## Step 4: Initialize Tensors with Random Data
+
+```cpp
+ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
+ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
+c_host.SetZero();
+```
+
+**What's happening:**
+- A and B are filled with random values in the range [-5.0, 5.0]
+- C is initialized to zero (it will store the output)
+
+**Optional: Print Tensor Contents**
+```cpp
+// Commented out in the code, but available for debugging:
+// a_host.print_first_n(10); // Print first 10 elements of A
+```
+
+The `print_first_n()` helper function can display tensor contents for debugging purposes.
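+
+Before moving on, it helps to see the "B stored as N × K" convention in the simplest possible form. The sketch below is a hypothetical plain-C++ helper, independent of the `reference_basic_gemm` shipped with this tutorial, but it computes the same thing: both operands are traversed along rows of length K:
+
+```cpp
+#include <vector>
+
+// C(M x N) = A(M x K) * B(N x K)^T, all buffers row-major with
+// stride equal to the trailing dimension.
+void naive_gemm(const std::vector<float>& a,
+                const std::vector<float>& b,
+                std::vector<float>& c,
+                int M, int N, int K)
+{
+    for(int m = 0; m < M; ++m)
+        for(int n = 0; n < N; ++n)
+        {
+            float acc = 0.f;
+            for(int k = 0; k < K; ++k)
+                acc += a[m * K + k] * b[n * K + k]; // row of A . row of B
+            c[m * N + n] = acc;
+        }
+}
+```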
+
+---
+
+## Step 5: Allocate Device Memory and Transfer Data
+
+```cpp
+ck_tile::DeviceMem a_device(a_host);
+ck_tile::DeviceMem b_device(b_host);
+ck_tile::DeviceMem c_device(c_host);
+```
+
+**What's happening:**
+- `DeviceMem` allocates GPU memory matching the size of the host tensors
+- The constructor **automatically transfers data from host to device**
+- This is a convenience wrapper around `hipMalloc` and `hipMemcpy`
+
+**Memory Flow:**
+```
+CPU (Host)          GPU (Device)
+┌─────────┐         ┌─────────┐
+│ a_host  │ ──────> │a_device │
+│ b_host  │ ──────> │b_device │
+│ c_host  │ ──────> │c_device │
+└─────────┘         └─────────┘
+```
+
+---
+
+## Step 6: Configure Hierarchical Tiling
+
+```cpp
+using BlockTile = ck_tile::sequence<256, 128, 32>;
+using WaveTile  = ck_tile::sequence<16, 16, 16>;
+```
+
+**What's happening:**
+- We define a two-level tiling hierarchy for the GEMM computation
+
+### Block Tile (256 × 128 × 32)
+- **256**: M dimension per block (rows of A and C)
+- **128**: N dimension per block (columns of B and C)
+- **32**: K dimension per block (inner dimension)
+- Each block tile is processed by one **thread block** (256 threads)
+
+### Wave Tile (16 × 16 × 16)
+- **16 × 16**: Output tile dimensions (M × N) per warp iteration
+- **16**: K dimension per warp iteration
+- Each wave tile is processed by one **warp** (64 threads on AMD GPUs)
+
+**Important:** The WaveTile (16×16×16) is NOT the same as the MFMA instruction size (32×32×8). The WaveTile represents the work done per warp per iteration, while MFMA is the underlying hardware instruction. Multiple MFMA operations may be needed to compute one wave tile.
+
+**Important Note:**
+In this example, the problem size (512 × 256 × 64) is **larger** than the block tile size (256 × 128 × 32), so a 2 × 2 grid of **four thread blocks** is needed to cover C, and each block loops twice along K (see Step 8).
+
+### Tiling Visualization:
+
+#### Block tile of A (256 × 32 in M × K):
+```
+┌─────────────────────────────────────┐
+│ One Block Tile (256 × 32)           │
+│ ┌────┬────┐                         │
+│ │16× │16× │  ← Wave tiles (16×16)   │
+│ │ 16 │ 16 │    in M×K space         │
+│ ├────┼────┤                         │
+│ │    │    │                         │
+│ ├────┼────┤                         │
+│ │ .. │ .. │  16 tiles in M          │
+│ ├────┼────┤  2 tiles in K           │
+│ │    │    │                         │
+│ └────┴────┘                         │
+│                                     │
+└─────────────────────────────────────┘
+```
+
+#### Block tile of B (128 × 32 in N × K):
+```
+┌──────────────────────────────┐
+│ One Block Tile (128 × 32)    │
+│ ┌────┬────┐                  │
+│ │16× │16× │  ← Wave tiles    │
+│ │ 16 │ 16 │    (16×16)       │
+│ ├────┼────┤                  │
+│ │    │    │                  │
+│ ├────┼────┤  8 tiles in N    │
+│ │ .. │ .. │  2 tiles in K    │
+│ ├────┼────┤                  │
+│ │    │    │                  │
+│ └────┴────┘                  │
+└──────────────────────────────┘
+```
+
+#### Block tile of C (256 × 128 in M × N) - Output:
+```
+┌─────────────────────────────────────────────────┐
+│ One Block Tile (256 × 128)                      │
+│                                                 │
+│ ┌────┬────┬────┬────┬────┬────┬────┬────┐       │
+│ │16× │    │    │    │    │    │    │    │       │
+│ │ 16 │    │    │    │    │    │    │    │       │
+│ ├────┼────┼────┼────┼────┼────┼────┼────┤       │
+│ │    │    │    │    │    │    │    │    │       │
+│ ├────┼────┼────┼────┼────┼────┼────┼────┤       │
+│ │    │    │    │    │    │    │    │    │       │
+│ ├────┼────┼────┼────┼────┼────┼────┼────┤       │
+│ │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ .. 
│ │
+├────┼────┼────┼────┼────┼────┼────┼────┤       │
+│    │    │    │    │    │    │    │    │       │
+└────┴────┴────┴────┴────┴────┴────┴────┘       │
+│                                                │
+│ 16 wave tiles in M direction                   │
+│ 8 wave tiles in N direction                    │
+│ Total: 128 wave tiles (16×16 each)             │
+└─────────────────────────────────────────────────┘
+```
+
+#### How Wave Tiles Combine (C = A × B):
+```
+Block tile of A    Block tile of B (stored transposed N×K)    Block tile of C
+(256×32)           (128×32)                                   (256×128)
+
+Row of A tiles:    Row of B tiles:          One wave tile in C:
+┌────┬────┐        ┌────┬────┐              ┌────┐
+│ A₀ │ A₁ │   ×    │ B₀ │ B₁ │      =       │ C  │ (16×16)
+└────┴────┘        └────┴────┘              └────┘
+ 16×16 each         16×16 each
+
+Computation: C = A₀×B₀ᵀ + A₁×B₁ᵀ
+                 ↑       ↑
+              K=0..15  K=16..31
+
+Each wave tile in C is computed by:
+- Taking one row of wave tiles from A (2 tiles along K)
+- Taking one row of wave tiles from B (2 tiles along K)
+  Note: B is stored transposed (N×K), so a "row" in storage corresponds
+  to a "column" in the logical B^T matrix used in computation
+- Performing dot product: Σ(A_k × B_k^T) for k=0,1
+```
+
+**Key Insight:**
+- Each **wave tile in C** (16×16) requires a **dot product** of 2 wave tiles from A and 2 wave tiles from B
+- Since B is stored transposed (N×K layout), we access **rows** of B tiles in memory
+- This is the fundamental operation repeated across all 128 wave tiles in a block tile of C
+- Each warp computes one wave tile using MFMA instructions
+
+---
+
+## Step 7: Create Shape, Problem, and Policy Structs
+
+```cpp
+using PracticeGemmShape = ck_tile::PracticeGemmShape<BlockTile, WaveTile>;
+std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl;
+
+using PracticeGemmHostProblem = ck_tile::PracticeGemmHostProblem<ADataType,
+                                                                 BDataType,
+                                                                 CDataType,
+                                                                 AccDataType,
+                                                                 PracticeGemmShape>;
+
+using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
+```
+
+**What's happening:**
+
+### 1. **Shape Struct**
+Encapsulates all tile shape information (BlockTile and WaveTile dimensions).
+
+### 2. **Problem Struct**
+Holds the complete problem description:
+- Data types (ADataType, BDataType, CDataType, AccDataType)
+- Shape information (BlockTile, WaveTile)
+
+In more complex examples, this would also include:
+- Data layouts (row-major, column-major)
+- Mathematical operations (e.g., transposed GEMM)
+
+### 3. **Policy Struct**
+Describes data movement and thread-to-data mapping:
+- Currently contains `MakeBlock2TileMap()`: Maps thread block IDs to tile positions
+- In more complex kernels, includes:
+  - DRAM access patterns
+  - LDS (Local Data Share) usage strategies
+  - Thread distribution within blocks
+
+**CK Tile Design Pattern:**
+```
+Kernel = Problem + Policy + Epilogue
+           ↑         ↑         ↑
+        (What)     (How)  (Post-processing)
+```
+
+---
+
+## Step 8: Calculate Grid and Block Dimensions
+
+```cpp
+ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
+                             ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
+
+std::cout << "kGridSize: " << kGridSize << std::endl;
+
+constexpr ck_tile::index_t kBlockSize  = 256;
+constexpr ck_tile::index_t kBlockPerCU = 1;
+```
+
+**What's happening:**
+
+### Grid Size Calculation
+```cpp
+kGridSize = ceil(M / BlockTile_M) × ceil(N / BlockTile_N)
+          = ceil(512 / 256) × ceil(256 / 128)
+          = 2 × 2
+          = 4 thread blocks
+```
+
+Our problem requires **4 thread blocks** to cover the entire output matrix C (2 blocks in M direction, 2 blocks in N direction).
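+
+A quick way to internalize the grid layout is to decode each 1-D block id into a 2-D tile coordinate by hand. The sketch below is plain C++ for illustration; the kernel itself does this through `Policy::MakeBlock2TileMap`, and the M-fastest ordering shown here is an assumption chosen to mirror that policy's merge of (N0, M0):
+
+```cpp
+#include <cstdio>
+
+int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }
+
+int main()
+{
+    const int M = 512, N = 256;
+    const int BlockTile_M = 256, BlockTile_N = 128;
+
+    const int num_tile_m = integer_divide_ceil(M, BlockTile_M); // 2
+    const int num_tile_n = integer_divide_ceil(N, BlockTile_N); // 2
+    const int grid_size  = num_tile_m * num_tile_n;             // 4
+
+    for(int block_id = 0; block_id < grid_size; ++block_id)
+    {
+        const int tile_id_m = block_id % num_tile_m; // M varies fastest (assumption)
+        const int tile_id_n = block_id / num_tile_m;
+        std::printf("block %d -> tile (m=%d, n=%d), C origin (%d, %d)\n",
+                    block_id, tile_id_m, tile_id_n,
+                    tile_id_m * BlockTile_M, tile_id_n * BlockTile_N);
+    }
+}
+```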
+
+### Block Configuration
+- `kBlockSize = 256`: Each thread block has 256 threads
+  - 256 threads / 64 threads per warp = **4 warps per block**
+- `kBlockPerCU = 1`: Launch 1 block per Compute Unit (for simplicity)
+
+**Thread Hierarchy:**
+```
+GPU
+└── Grid: 4 Thread Blocks, each with 256 Threads
+    ├── Warp 0 (threads 0-63)
+    ├── Warp 1 (threads 64-127)
+    ├── Warp 2 (threads 128-191)
+    └── Warp 3 (threads 192-255)
+```
+
+---
+
+## Step 9: Create and Launch the Kernel
+
+```cpp
+using gemm_kernel =
+    ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
+
+float ave_time = ck_tile::launch_kernel(
+    ck_tile::stream_config{nullptr, true, 0, 0, 1},
+    ck_tile::make_kernel(gemm_kernel{},
+                         kGridSize,
+                         kBlockSize,
+                         0,
+                         static_cast<ADataType*>(a_device.GetDeviceBuffer()),
+                         static_cast<BDataType*>(b_device.GetDeviceBuffer()),
+                         static_cast<CDataType*>(c_device.GetDeviceBuffer()),
+                         M,
+                         N,
+                         K,
+                         stride_a,
+                         stride_b,
+                         stride_c));
+```
+
+**What's happening:**
+
+### 1. Kernel Composition
+```cpp
+using gemm_kernel = ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
+```
+The kernel is composed from the Problem and Policy structs, following the CK Tile design pattern.
+
+### 2. Kernel Launch
+`launch_kernel()` is a CK Tile utility that:
+- Launches the GPU kernel using the HIP runtime
+- Measures execution time
+- Returns the average execution time in milliseconds
+
+### 3. Launch Parameters
+- **Stream config**: `{nullptr, true, 0, 0, 1}` - default stream, timing enabled
+- **Grid size**: `kGridSize = 4` - number of thread blocks
+- **Block size**: `kBlockSize = 256` - threads per block
+- **Shared memory**: `0` - no dynamic shared memory in this example
+- **Kernel arguments**: Device pointers and problem dimensions
+
+### 4. Kernel Execution Flow
+```
+launch_kernel() calls gemm_kernel.operator()()
+        ↓
+PracticeGemmKernel::operator()
+        ↓
+Creates tensor views over device memory
+        ↓
+Calls block-level pipeline
+        ↓
+Block pipeline calls warp-level pipeline
+        ↓
+Warp pipeline calls MFMA instructions
+        ↓
+Results written back to C matrix
+```
+
+---
+
+## Step 10: Verify Results
+
+```cpp
+auto pass = true;
+
+if(verification)
+{
+    // Reference gemm on CPU
+    ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
+    reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
+        a_host, b_host, c_host_ref);
+
+    // Copy GPU results back to host
+    ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
+    c_device.FromDevice(c_host_dev.mData.data());
+
+    // Compare results
+    pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
+    std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
+}
+```
+
+**What's happening:**
+
+### 1. CPU Reference Implementation
+```cpp
+reference_basic_gemm<...>(a_host, b_host, c_host_ref);
+```
+Computes GEMM on the CPU using a simple nested-loop implementation (ground truth).
+
+### 2. Copy GPU Results to Host
+```cpp
+c_device.FromDevice(c_host_dev.mData.data());
+```
+Transfers the computed result from GPU memory back to the CPU for comparison.
+
+### 3. 
Error Checking
+```cpp
+ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
+```
+Compares GPU and CPU results element-wise with tolerance:
+- **Relative error**: 1e-3 (0.1%)
+- **Absolute error**: 1e-3
+
+**Verification Flow:**
+```
+CPU                      GPU
+┌─────────┐         ┌─────────┐
+│ a_host  │ ──────> │a_device │
+│ b_host  │ ──────> │b_device │
+└─────────┘         └─────────┘
+     │                   │
+     ↓                   ↓
+reference_gemm()     GPU kernel
+     │                   │
+     ↓                   ↓
+┌──────────┐        ┌──────────┐
+│c_host_ref│        │c_device  │
+└──────────┘        └──────────┘
+     │                   │
+     │                   ↓
+     │             FromDevice()
+     │                   │
+     ↓                   ↓
+     └────> check_err() <┘
+              │
+              ↓
+          Pass/Fail
+```
+
+---
+
+## Complete Execution Flow Summary
+
+```
+1. Define data types (FP16 inputs, FP32 output)
+   ↓
+2. Set problem size (M=512, N=256, K=64)
+   ↓
+3. Create host tensors and initialize with random data
+   ↓
+4. Allocate device memory and transfer data (CPU → GPU)
+   ↓
+5. Configure hierarchical tiling (BlockTile, WaveTile)
+   ↓
+6. Create Shape, Problem, and Policy structs
+   ↓
+7. Calculate grid/block dimensions (4 blocks, 256 threads each)
+   ↓
+8. Compose and launch kernel (Problem + Policy)
+   ↓
+9. Execute GEMM on GPU
+   │  ├─ Block-level pipeline
+   │  ├─ Warp-level pipeline
+   │  └─ MFMA instructions
+   ↓
+10. Verify results (compare GPU vs CPU reference)
+   ↓
+11. Calculate and print performance metrics
+   ↓
+12. Return success/failure
+```
+
+---
\ No newline at end of file
diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp
new file mode 100644
index 0000000000..31fa4ac3eb
--- /dev/null
+++ b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp
@@ -0,0 +1,165 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
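+//
+// Block-level GEMM pipeline. Each thread block repeatedly:
+//   1. loads one A (MPerBlock x KPerBlock) and one B (NPerBlock x KPerBlock)
+//      tile from global memory into registers,
+//   2. stages them into LDS,
+//   3. runs the block GEMM (which dispatches to the warp-level pipeline),
+// accumulating into a C register tile while sliding the DRAM windows along K.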
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +namespace ck_tile { + +template +struct PracticeGemmBlockPipelineAGmemBGmemCreg +{ + using ADataType = typename Problem::ADataType; + using BDataType = typename Problem::BDataType; + using CDataType = typename Problem::CDataType; + using AccDataType = typename Problem::AccDataType; + + using BlockTile = typename Problem::Shape::BlockTile; + using WaveTile = typename Problem::Shape::WaveTile; + + static constexpr index_t MPerBlock = BlockTile::at(number<0>{}); + static constexpr index_t NPerBlock = BlockTile::at(number<1>{}); + static constexpr index_t KPerBlock = BlockTile::at(number<2>{}); + + static constexpr index_t MPerWave = WaveTile::at(number<0>{}); + static constexpr index_t NPerWave = WaveTile::at(number<1>{}); + static constexpr index_t KPerWave = WaveTile::at(number<2>{}); + + using BlockGemm = + remove_cvref_t())>; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLDSSize() + { + return integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // ----------------------------------------------------------------------------------------- + // Definitions of all needed tiles + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), 
{0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock); + + // ------------------------------------------------------------------------------------- + // Gemm pipeline start + + // Initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // non-prefetch + index_t iCounter = num_loop; + + while(iCounter > 0) + { + a_block_tile = load_tile(a_copy_dram_window); // from DRAM to registers + b_block_tile = load_tile(b_copy_dram_window); // from DRAM to registers + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + store_tile(a_copy_lds_window, a_block_tile); // from registers to LDS + store_tile(b_copy_lds_window, b_block_tile); // from registers to LDS + + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // from LDS to registers + block_sync_lds(); + + iCounter--; + } + + return c_block_tile; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..99c4379ad8 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp @@ -0,0 +1,135 @@ +#pragma once + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" + +#include "../warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp" +#include "../warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp" + +namespace ck_tile { + +template +struct PracticeGemmBlockPipelineProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + using AccDataType = AccDataType_; + using Shape = Shape_; +}; + +struct PracticeGemmBlockPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetPracticeWaveGemmPipeline() + { + return PracticeGemmWarpPipelineASmemBSmemCreg{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + constexpr index_t kKPack = 8; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + 
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return a_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + constexpr index_t kKPack = 8; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + using BlockGemm = remove_cvref_t())>; + constexpr index_t kMWarp = BlockGemm::MWarp; + constexpr index_t kNWarp = BlockGemm::NWarp; + constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size(); + + constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + using BlockGemm = remove_cvref_t())>; + constexpr index_t kMWarp = BlockGemm::MWarp; + constexpr index_t kNWarp = BlockGemm::NWarp; + constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size(); + + constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..ef12634e42 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
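+//
+// Host-level (grid-level) pipeline. Each thread block maps its block id to a
+// (tile_m, tile_n) coordinate, opens DRAM tile windows over A and B at that
+// origin, runs the block-level pipeline for ceil(K / KPerBlock) iterations
+// using statically sized LDS, and stores the resulting C tile back to DRAM.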
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +namespace ck_tile { +template +struct PracticeGemmHostPipeline +{ + using ADataType = typename Problem_::ADataType; + using BDataType = typename Problem_::BDataType; + using CDataType = typename Problem_::CDataType; + using AccDataType = typename Problem_::AccDataType; + + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using BlockTile = typename Problem::Shape::BlockTile; + using WaveTile = typename Problem::Shape::WaveTile; + + template + CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, + const BDRAMTensorView& b_dram, + CDRAMTensorView& c_dram_ref) const + { + + // Size of the entire problem + const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K + const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N + const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K + + // Size of the block tile + const auto MPerBlock = BlockTile::at(number<0>{}); + const auto NPerBlock = BlockTile::at(number<1>{}); + const auto KPerBlock = BlockTile::at(number<2>{}); + + // Number of block tile in the N direction to cover C (resultant) matrix + const auto num_tile_n = integer_divide_ceil(N, NPerBlock); + // Number of block tile in the M direction to cover C (resultant) matrix + const auto num_tile_m = integer_divide_ceil(M, MPerBlock); + + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("num_tile_m: %d, num_tile_n: %d\n", num_tile_m, num_tile_n); + // printf("total number of tiles: %d\n", num_tile_m * num_tile_n); + // } + + // Get block id + const auto id_block = + get_block_id(); // 0 to (M_block/BlockTile_M) * (N_block/BlockTile_N) - 1 + + // Map block id to tile id + const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n); + + const auto tile_id = block2tile(id_block); + + const auto tile_id_m = tile_id.at(number<0>{}); + const auto tile_id_n = tile_id.at(number<1>{}); + + // if(get_thread_id() == 0 && get_block_id() == 15) + // { + // printf("tile_id_m: %d, tile_id_n: %d\n", tile_id_m, tile_id_n); + // } + + const auto tile_origin_m = tile_id_m * MPerBlock; + const auto tile_origin_n = tile_id_n * NPerBlock; + + // create a tile window over dram for A and B + const auto a_block_window = make_tile_window( + a_dram, make_tuple(number{}, number{}), {tile_origin_m, 0}); + + const auto b_block_window = make_tile_window( + b_dram, make_tuple(number{}, number{}), {tile_origin_n, 0}); + + constexpr auto block_gemm_pipeline = + Policy::template GetPracticeGemmBlockPipeline(); + + int num_loops_k = integer_divide_ceil(K, KPerBlock); + + __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; + const auto c_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); + auto c_window = make_tile_window(c_dram, + make_tuple(number{}, number{}), + {tile_origin_m, tile_origin_n}); + store_tile(c_window, c_block_tile); + } +}; +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..d66c3c8522 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" + +#include 
"../block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp" +#include "../block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp" + +namespace ck_tile { + +template +struct PracticeGemmHostProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + using AccDataType = AccDataType_; + using Shape = remove_cvref_t; +}; + +struct PracticeGemmHostPolicy +{ + CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) + { + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPracticeGemmBlockPipeline() + { + using PracticeGemmBlockPipelineProblem_ = + PracticeGemmBlockPipelineProblem; + return PracticeGemmBlockPipelineAGmemBGmemCreg{}; + } +}; +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp new file mode 100644 index 0000000000..ee2e125e24 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "ck_tile/host.hpp" +#include "practice_gemm.hpp" +#include "reference_gemm.hpp" + +int main() +{ + // TODO: GemmTypeConfig + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using CDataType = float; + using AccDataType = float; + + // ArgParser + ck_tile::index_t M = 512; + ck_tile::index_t N = 256; + ck_tile::index_t K = 64; + ck_tile::index_t verification = 1; + + ck_tile::index_t stride_a = K; + ck_tile::index_t stride_b = K; + ck_tile::index_t stride_c = N; + + auto a_lengths = std::array{M, K}; + auto b_lengths = std::array{N, K}; + auto c_lengths = std::array{M, N}; + + auto a_strides = std::array{stride_a, 1}; + auto b_strides = std::array{stride_b, 1}; + auto c_strides = std::array{stride_c, 1}; + + // tensors on host (cpu) + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::HostTensor b_host(b_lengths, b_strides); + ck_tile::HostTensor c_host(c_lengths, c_strides); + + // initialize tensors + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); + c_host.SetZero(); + + // Print the tensors using the new print_first_n member function + // std::cout << "Tensor A (first 10 elements): "; + // a_host.print_first_n(10); + // std::cout << std::endl; + + // std::cout << "Tensor B (first 10 elements): "; + // b_host.print_first_n(10); + // std::cout << std::endl; + + // std::cout << "Tensor C (first 10 elements): "; + // c_host.print_first_n(10); + // std::cout << std::endl; + + // Create device tensors of same size as host tensors and copy data + ck_tile::DeviceMem a_device(a_host); + ck_tile::DeviceMem b_device(b_host); + ck_tile::DeviceMem c_device(c_host); + + // TODO: BlockTileConfig + // constexpr ck_tile::index_t warpSize = 64; + constexpr ck_tile::index_t kBlockSize = 256; + + using BlockTile = ck_tile::sequence<256, 128, 32>; + using WaveTile = ck_tile::sequence<16, 16, 16>; + + std::cout << "Creating PracticeGemmShape, PracticeGemmProblem, PracticeGemmPolicy" << std::endl; + using PracticeGemmShape = ck_tile::PracticeGemmShape; + std::cout << "PracticeGemmShape: " << 
PracticeGemmShape::GetName() << std::endl; + using PracticeGemmHostProblem = ck_tile:: + PracticeGemmHostProblem; + using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy; + + ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) * + ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N); + + std::cout << "kGridSize: " << kGridSize << std::endl; + constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU + + std::cout << "kBlockSize: " << kBlockSize << std::endl; + std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl; + + using gemm_kernel = + ck_tile::PracticeGemmKernel; + + float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel(gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_device.GetDeviceBuffer()), + static_cast(b_device.GetDeviceBuffer()), + static_cast(c_device.GetDeviceBuffer()), + M, + N, + K, + stride_a, + stride_b, + stride_c)); + + auto pass = true; + + if(verification) + { + // reference gemm + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm( + a_host, b_host, c_host_ref); + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + c_device.FromDevice(c_host_dev.mData.data()); + pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp new file mode 100644 index 0000000000..88879ee221 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
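+//
+// Defines PracticeGemmShape (BlockTile/WaveTile dimensions) and the
+// PracticeGemmKernel entry point, which wraps the raw A/B/C pointers in
+// global-memory tensor views and dispatches to the host-level pipeline.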
+ +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp" +#include "host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp" + +namespace ck_tile { + +template +struct PracticeGemmShape +{ + using BlockTile = remove_cvref_t; + using WaveTile = remove_cvref_t; + + static constexpr index_t BlockTile_M = BlockTile::at(number<0>{}); + static constexpr index_t BlockTile_N = BlockTile::at(number<1>{}); + static constexpr index_t BlockTile_K = BlockTile::at(number<2>{}); + + static constexpr index_t WaveTile_M = WaveTile::at(number<0>{}); + static constexpr index_t WaveTile_N = WaveTile::at(number<1>{}); + static constexpr index_t WaveTile_K = WaveTile::at(number<2>{}); + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + return concat('_', "practice_gemm_shape", + concat('x', BlockTile_M, BlockTile_N, BlockTile_K), + concat('x', WaveTile_M, WaveTile_N, WaveTile_K)); + // clang-format on + } +}; + +template +struct PracticeGemmKernel +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + static constexpr index_t kBlockSize = 256; + + CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a, + const typename Problem::BDataType* p_b, + typename Problem::CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c) const + { + + auto a_dram = make_naive_tensor_view( + p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{}); + + auto b_dram = make_naive_tensor_view( + p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{}); + + const auto c_dram = make_naive_tensor_view( + p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{}); + + PracticeGemmHostPipeline{}(a_dram, b_dram, c_dram); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp b/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp new file mode 100644 index 0000000000..8f975be7dc --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
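+//
+// CPU reference implementation used for verification: a straightforward
+// M x N x K loop nest over row-major A (M x K) and B (N x K), accumulating
+// in AccDataType and converting to CDataType on store.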
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +void reference_basic_gemm(const ck_tile::HostTensor& a_m_k, + const ck_tile::HostTensor& b_n_k, + ck_tile::HostTensor& c_m_n) +{ + const int N = b_n_k.mDesc.get_lengths()[0]; + const int K = b_n_k.mDesc.get_lengths()[1]; + + auto f = [&](auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_n_k(n, k); + + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); + } + + c_m_n(m, n) = ck_tile::type_convert(v_acc); + } + }; + + ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(1); +} diff --git a/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp new file mode 100644 index 0000000000..bf058af9c5 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp @@ -0,0 +1,195 @@ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +namespace ck_tile { + +template +struct PracticeGemmWarpPipelineASmemBSmemCreg +{ + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using WaveGemmShape = remove_cvref_t; + + using WarpGemm = remove_cvref_t< + decltype(Policy::template GetWarpGemmMWarpNWarp().template get<0>())>; + static constexpr index_t MWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<1>(); + static constexpr index_t NWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<2>(); + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == WaveGemmShape::BlockTile_M && + NPerBlock == WaveGemmShape::BlockTile_N && + KPerBlock == WaveGemmShape::BlockTile_K, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + +#if !defined(ENABLE_PREFETCH) + constexpr 
index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // Warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == WaveGemmShape::BlockTile_M && + NPerBlock == WaveGemmShape::BlockTile_N && + KPerBlock == WaveGemmShape::BlockTile_K, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * 
WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + + static_assert(std::is_same_v, "wrong!"); + + // Construct C-Block-Tensor + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp new file mode 100644 index 0000000000..2efa2bcc2a --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCReg +// Default policy class should not be templated, put template on member functions instead +struct PracticeGemmWarpPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + constexpr index_t kMWarp = 4; + constexpr index_t kNWarp = 1; + + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); + } + else + { + static_assert(false, "Unsupported data type configuration for GEMM warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/CMakeLists.txt b/tutorial/ck_tile/CMakeLists.txt new file mode 100644 index 0000000000..9895f5a71d --- /dev/null +++ b/tutorial/ck_tile/CMakeLists.txt @@ -0,0 +1,7 @@ +include_directories(AFTER + ${CMAKE_CURRENT_LIST_DIR} +) + +add_subdirectory(00_copy_kernel) +add_subdirectory(01_naive_gemm) +