mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
Add CK Tile Tutorials Folder with GEMM and COPY Kernel (#3038)
* feat: add tutorial folder with gemm tutorial * chore: move copy kernel from examples folder to tutorial * Update tutorial/ck_tile/01_naive_gemm/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tutorial/ck_tile/01_naive_gemm/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * chore: remove handdrawn images * docs: add write ups to explain the gemm kernel * docs: add about block level pipeline and static distributed tensors --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
9
tutorial/ck_tile/00_copy_kernel/CMakeLists.txt
Normal file
9
tutorial/ck_tile/00_copy_kernel/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
add_executable(tile_tutorial_copy_kernel EXCLUDE_FROM_ALL copy_basic.cpp)
|
||||
|
||||
# Impact: This flag ensures that the compiler doesn't make
|
||||
# assumptions about memory aliasing that could interfere with Composable Kernel's explicit memory access patterns.
|
||||
target_compile_options(tile_tutorial_copy_kernel PRIVATE
|
||||
-mllvm -enable-noalias-to-md-conversion=0
|
||||
)
|
||||
|
||||
add_dependencies(tutorials tile_tutorial_copy_kernel)
|
||||
315
tutorial/ck_tile/00_copy_kernel/README.md
Normal file
315
tutorial/ck_tile/00_copy_kernel/README.md
Normal file
@@ -0,0 +1,315 @@
|
||||
# CK Tile Framework: Getting Started with Tile Copy Operations
|
||||
|
||||
## Overview
|
||||
|
||||
### Copy Kernel
|
||||
A minimal CK_Tile memory copy implementation demonstrating the basic setup required to write a kernel in CK Tile.
|
||||
This experimental kernel is intended for novice CK developers. It introduces the building blocks of CK Tile and provides a sandbox for experimenting with kernel parameters.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture
|
||||
# (for example gfx90a or gfx942) or leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Make the copy kernel executable
|
||||
make tile_example_copy -j
|
||||
```
|
||||
This will result in an executable `build/bin/test_copy_basic`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m input matrix rows. (default 64)
|
||||
-n input matrix cols. (default 8)
|
||||
-id wave to use for computation. (default 0)
|
||||
-v validation flag to check device results. (default 1)
|
||||
-prec datatype precision to use. (default fp16)
|
||||
-warmup no. of warmup iterations. (default 50)
|
||||
-repeat no. of iterations for kernel execution time. (default 100)
|
||||
```
|
||||
|
||||
## CK Tile Architecture Components
|
||||
|
||||
The CK Tile framework is built around four key architectural components that work together to define and execute GPU kernels: shape, policy, problem, and pipeline.
|
||||
|
||||
### **1. Shape**
|
||||
Defines the **hierarchical tile structure** and **memory layout** of the kernel:
|
||||
|
||||
```cpp
|
||||
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
|
||||
```
|
||||
|
||||
**Components:**
|
||||
- **BlockWaves**: Number of concurrent waves per block (e.g., `seq<4, 1>` for 4 waves along M, 1 along N)
|
||||
- **BlockTile**: Total elements processed by one block (e.g., `seq<512, 8>`)
|
||||
- **WaveTile**: Elements processed by one wave (e.g., `seq<32, 8>`)
|
||||
- **ThreadTile**: Elements processed by one thread (e.g., `seq<1, 4>` for 4 contiguous elements)
|
||||
|
||||
**Purpose**: Defines the **work distribution hierarchy** from threads → waves → blocks.
|
||||
|
||||
### **2. Problem**
|
||||
Defines the **data types** and **kernel configuration**:
|
||||
|
||||
```cpp
|
||||
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
|
||||
```
|
||||
|
||||
**Components:**
|
||||
- **XDataType**: Input/output data type (e.g., `float`, `half`)
|
||||
- **Shape**: The tile shape defined above
|
||||
|
||||
**Purpose**: Encapsulates **what** the kernel operates on and **how** it's configured.
|
||||
|
||||
### **3. Policy**
|
||||
Defines the **memory access patterns** and **distribution strategies**:
|
||||
|
||||
```cpp
|
||||
using Policy = ck_tile::TileCopyPolicy<Problem>;
|
||||
```
|
||||
|
||||
**Key Functions:**
|
||||
- **MakeDRAMDistribution()**: Defines how threads access DRAM memory.
|
||||
|
||||
**Purpose**: Defines **how** data is accessed and distributed across threads.
|
||||
|
||||
### **4. Pipeline**
|
||||
Defines the **execution flow** and **memory movement patterns**:
|
||||
|
||||
```cpp
|
||||
// Example pipeline stages:
|
||||
// 1. DRAM → Registers (load_tile)
|
||||
// 2. Registers → LDS (store_tile)
|
||||
// 3. LDS → Registers (load_tile with distribution)
|
||||
// 4. Registers → DRAM (store_tile)
|
||||
```
|
||||
|
||||
**Purpose**: Defines the **sequence of operations** and **memory movement strategy**.
|
||||
|
||||
### **Component Interaction**
|
||||
|
||||
```cpp
|
||||
// Complete kernel definition
|
||||
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
|
||||
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
|
||||
using Policy = ck_tile::TileCopyPolicy<Problem>;
|
||||
using Kernel = ck_tile::TileCopyKernel<Problem, Policy>;
|
||||
```
|
||||
|
||||
**Flow:**
|
||||
1. **Shape** defines the tile structure and work distribution
|
||||
2. **Problem** combines data types with the shape
|
||||
3. **Policy** defines memory access patterns for the problem
|
||||
4. **Kernel** implements the actual computation using all components
|
||||
|
||||
### **Why This Architecture?**
|
||||
|
||||
#### **Separation of Concerns**
|
||||
- **Shape**: Focuses on **work distribution** and **tile structure**
|
||||
- **Problem**: Focuses on **data types** and **configuration**
|
||||
- **Policy**: Focuses on **memory access** and **optimization**
|
||||
- **Pipeline**: Focuses on **execution flow** and **synchronization**
|
||||
|
||||
#### **Reusability**
|
||||
- Same **Shape** can be used with different **Problems**
|
||||
- Same **Policy** can be applied to different **Problems**
|
||||
- **Pipelines** can be reused across different kernels
|
||||
|
||||
#### **Performance Optimization**
|
||||
- **Shape** enables optimal work distribution
|
||||
- **Policy** enables optimal memory access patterns
|
||||
- **Pipeline** enables optimal execution flow
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### Hierarchical Tile Structure
|
||||
|
||||
The CK Tile framework organizes work in a hierarchical manner:
|
||||
|
||||
1. **ThreadTile**: Number of contiguous elements processed by a single thread
|
||||
- Enables vectorized memory loads/stores.
|
||||
- Example: `ThreadTile = seq<1, 4>` means each thread loads 4 contiguous elements along the N dimension
|
||||
- A ThreadTile can be imagined as a thread-level tile
|
||||
|
||||
2. **WaveTile**: Number of elements covered by a single wave (64 threads on CDNA, 32 threads on RDNA)
|
||||
- Must satisfy: `Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize`
|
||||
- This ensures the number of threads needed equals the wave size
|
||||
- Example: `WaveTile = seq<64, 4>` with `ThreadTile = seq<1, 4>` means:
|
||||
- Each thread handles 4 elements (ThreadTile_N = 4)
|
||||
- Wave needs 64×4/4 = 64 threads to cover 64×4 = 256 elements
|
||||
- Total elements = 256, which requires WaveSize = 64 threads
|
||||
|
||||
3. **BlockTile**: Number of elements covered by one block (typically mapped to one CU)
|
||||
- Example: `BlockTile = seq<256, 64>` means each block processes 256×64 elements
|
||||
|
||||
4. **BlockWaves**: Number of concurrent waves active in a block
|
||||
- Typical: 4 waves for heavy workloads (e.g., GEMM)
|
||||
- Limit: up to 1024 threads per block → up to 16 waves (CDNA) or 32 waves (RDNA)
|
||||
- Example: `BlockWaves = seq<4, 1>` means 4 waves along M, 1 along N
|
||||
|
||||
### Wave Repetition
|
||||
|
||||
In many scenarios, the total work (BlockTile) is larger than what the available waves can cover in a single iteration. This requires **wave repetition**:
|
||||
|
||||
```cpp
|
||||
// Calculate how many times a wave needs to repeat to cover the entire block tile
|
||||
static constexpr index_t WaveRepetitionPerBlock_M =
|
||||
Block_Tile_M / (Waves_Per_Block_M * Wave_Tile_M);
|
||||
static constexpr index_t WaveRepetitionPerBlock_N =
|
||||
Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
|
||||
```
|
||||
|
||||
**Key Insight**: When waves repeat, the effective work per thread becomes `ThreadTile * Repeat`, not just `ThreadTile`.
|
||||
|
||||
## Tile Distribution Encoding
|
||||
|
||||
The tile distribution encoding specifies how work is distributed across threads:
|
||||
|
||||
```cpp
|
||||
constexpr auto outer_encoding =
|
||||
tile_distribution_encoding<sequence<1>, // replication
|
||||
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>, // hierarchy
|
||||
tuple<sequence<1>, sequence<1, 2>>, // parallelism
|
||||
tuple<sequence<1>, sequence<2, 0>>, // paralleism
|
||||
sequence<1, 2>, // yield
|
||||
sequence<0, 1>>{}; // yield
|
||||
```
|
||||
|
||||
### Encoding Parameters Explained
|
||||
|
||||
- **M0, M1, M2**: Hierarchical distribution along M dimension
|
||||
- M0: Number of wave iterations along M
|
||||
- M1: Number of waves along M
|
||||
- M2: Number of threads per wave along M
|
||||
- **N0, N1**: Distribution along N dimension
|
||||
- N0: Number of threads along N
|
||||
- N1: ThreadTile size (elements per thread)
|
||||
- **Order and layout**: The inner-most (rightmost) dimension is the fastest-changing. Choosing `N1 = ThreadTile_N` maps vector width to contiguous addresses, i.e., row-major access in this example.
|
||||
- **YIELD arguments**: Both `Repeat` and `ThreadTile` because effective work per thread is `ThreadTile * Repeat`
|
||||
|
||||
## Tensor Abstractions
|
||||
|
||||
### Tensor Descriptor
|
||||
Defines the logical structure of a tensor:
|
||||
```cpp
|
||||
auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), // tensor dimensions
|
||||
make_tuple(N, 1), // strides
|
||||
number<ThreadTile_N>{}, // per-thread vector length
|
||||
number<1>{} // guaranteed last dimension vector stride
|
||||
);
|
||||
```
|
||||
|
||||
### Tensor View
|
||||
Combines memory buffer with tensor descriptor:
|
||||
```cpp
|
||||
auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, // memory buffer
|
||||
make_tuple(M, N), // dimensions
|
||||
make_tuple(N, 1), // strides
|
||||
number<S::ThreadTile_N>{}, // per-thread vector length
|
||||
number<1>{} // guaranteed last dimension vector stride
|
||||
);
|
||||
```
|
||||
|
||||
### Tile Window
|
||||
A view into a specific tile of the tensor with thread distribution:
|
||||
```cpp
|
||||
auto x_window = make_tile_window(
|
||||
x_m_n, // tensor view
|
||||
make_tuple(Block_Tile_M, Block_Tile_N), // tile size
|
||||
{iM, 0}, // tile origin
|
||||
tile_distribution // how work is distributed among threads
|
||||
);
|
||||
```
|
||||
|
||||
## The test_copy_basic Kernel
|
||||
|
||||
### Kernel Structure
|
||||
|
||||
The `TileCopyKernel` implements a basic copy operation from input tensor `x` to output tensor `y`:
|
||||
|
||||
```cpp
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct TileCopyKernel
|
||||
{
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
|
||||
{
|
||||
// 1. Create tensor views
|
||||
// 2. Create tile windows
|
||||
// 3. Iterate over N dimension tiles
|
||||
// 4. Load, copy, and store data
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### Step-by-Step Execution
|
||||
|
||||
1. **Tensor View Creation**:
|
||||
```cpp
|
||||
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
|
||||
```
|
||||
- Creates views for both input and output tensors
|
||||
- Specifies vectorized access with `ThreadTile_N` elements per load
|
||||
|
||||
2. **Tile Window Creation**:
|
||||
```cpp
|
||||
auto x_window = make_tile_window(x_m_n,
|
||||
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeDRAMDistribution<Problem>());
|
||||
```
|
||||
- Creates windows into specific tiles of the tensors
|
||||
- Each block processes one tile starting at `{iM, 0}`
|
||||
- Tile distribution determines how threads access data
|
||||
|
||||
3. **N-Dimension Iteration**:
|
||||
```cpp
|
||||
index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N));
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
```
|
||||
- If tensor N dimension > Block_Tile_N, multiple iterations are needed
|
||||
- Each iteration processes one tile along N dimension
|
||||
|
||||
4. **Load-Store Operations**:
|
||||
```cpp
|
||||
dram_reg_tile dram_tile;
|
||||
load_tile(dram_tile, x_window); // Load from global memory to registers
|
||||
store_tile(y_window, dram_tile); // Store from registers to global memory
|
||||
move_tile_window(x_window, {0, S::Block_Tile_N}); // Move to next N tile
|
||||
move_tile_window(y_window, {0, S::Block_Tile_N});
|
||||
```
|
||||
|
||||
### How Load/Store Works
|
||||
|
||||
1. **Load Tile**:
|
||||
- Each thread loads its assigned elements based on tile distribution
|
||||
- Vectorized loads enable efficient memory bandwidth utilization
|
||||
- Data is distributed to per-thread register buffers
|
||||
|
||||
2. **Store Tile**:
|
||||
- Each thread writes its assigned elements back to global memory
|
||||
- Maintains the same distribution pattern as load
|
||||
|
||||
3. **Tile Window Movement**:
|
||||
- Moves the window to the next tile along N dimension
|
||||
- Enables processing of large tensors that don't fit in one tile
|
||||
|
||||
## Memory Access Patterns
|
||||
|
||||
### Vectorized Access
|
||||
- Enabled by specifying vector length in tensor views
|
||||
- Each thread loads/stores multiple contiguous elements in one operation
|
||||
- Improves memory bandwidth utilization
|
||||
|
||||
### Thread Distribution
|
||||
- Tile distribution encoding determines which threads access which elements
|
||||
- Ensures all threads participate and no data is missed
|
||||
- Enables memory coalescing for optimal performance
|
||||
|
||||
### Coordinate Transform (Embed)
|
||||
- Maps multi-dimensional tensor indices to linear memory addresses
|
||||
- Handles stride calculations automatically
|
||||
- Enables efficient access to non-contiguous memory layouts
|
||||
148
tutorial/ck_tile/00_copy_kernel/copy_basic.cpp
Normal file
148
tutorial/ck_tile/00_copy_kernel/copy_basic.cpp
Normal file
@@ -0,0 +1,148 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include <cstring>
|
||||
#include "copy_basic.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "128", "m dimension")
|
||||
.insert("n", "8", "n dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision(fp16 or fp32)")
|
||||
.insert("warmup", "50", "cold iter")
|
||||
.insert("repeat", "100", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using YDataType = DataType;
|
||||
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
// Create host tensors
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}); // input matrix
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}); // reference output matrix
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}); // device output matrix
|
||||
|
||||
// Initialize input data with increasing values
|
||||
ck_tile::half_t value = 1;
|
||||
for(int i = 0; i < m; i++)
|
||||
{
|
||||
value = 1;
|
||||
for(int j = 0; j < n; j++)
|
||||
{
|
||||
x_host(i, j) = value++;
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
// Define tile configuration
|
||||
using ThreadTile = ck_tile::sequence<1, 4>; // per-thread tile size along M and N
|
||||
using WaveTile = ck_tile::sequence<64, 4>; // per-wave tile size along M and N dimension
|
||||
using BlockWaves = ck_tile::sequence<4, 1>; // number of waves per block along M and N dimension
|
||||
using BlockTile = ck_tile::sequence<512, 4>; // per-block tile size along M and N dimension
|
||||
|
||||
// Calculate grid size
|
||||
ck_tile::index_t kGridSize =
|
||||
ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{}));
|
||||
std::cout << "grid size (number of blocks per grid) " << kGridSize << std::endl;
|
||||
|
||||
// Define kernel types
|
||||
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, ThreadTile>;
|
||||
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
|
||||
using Policy = ck_tile::TileCopyPolicy<Problem>;
|
||||
using Kernel = ck_tile::ElementWiseTileCopyKernel<Problem, Policy>; // operates on element by
|
||||
// element basis.
|
||||
|
||||
// We also implement two variations of the copy kernel:
|
||||
// 1. TileCopyKernel: This is the basic copy kernel that operates on tile by tile basis.
|
||||
// 2. TileCopyKernel_LDS: This is the copy kernel that operates on tile by tile basis and uses
|
||||
// the LDS. using Kernel = ck_tile::TileCopyKernel<Problem, Policy>; using Kernel =
|
||||
// ck_tile::TileCopyKernel_LDS<Problem, Policy>;
|
||||
|
||||
auto blockSize = Kernel::BlockSize();
|
||||
|
||||
// Print configuration information
|
||||
std::cout << "block size (number of threads per block) " << blockSize << std::endl;
|
||||
std::cout << "wave size (number of threads per wave) " << ck_tile::get_warp_size() << std::endl;
|
||||
std::cout << "block waves (number of waves per block) " << BlockWaves::at(ck_tile::number<0>{})
|
||||
<< " " << BlockWaves::at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "block tile (number of elements per block) " << BlockTile::at(ck_tile::number<0>{})
|
||||
<< " " << BlockTile::at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "wave tile (number of elements per wave) " << WaveTile::at(ck_tile::number<0>{})
|
||||
<< " " << WaveTile::at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "thread tile (number of elements per thread) "
|
||||
<< ThreadTile::at(ck_tile::number<0>{}) << " " << ThreadTile::at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "WaveRepetitionPerBlock_M = " << Shape::WaveRepetitionPerBlock_M << " --> ("
|
||||
<< Shape::Block_Tile_M << "/" << Shape::Waves_Per_Block_M << "*" << Shape::Wave_Tile_M
|
||||
<< ")" << std::endl;
|
||||
std::cout << "WaveRepetitionPerBlock_N = " << Shape::WaveRepetitionPerBlock_N << " --> ("
|
||||
<< Shape::Block_Tile_N << "/" << Shape::Waves_Per_Block_N << "*" << Shape::Wave_Tile_N
|
||||
<< ")" << std::endl;
|
||||
|
||||
// Launch kernel
|
||||
float ave_time =
|
||||
launch_kernel(ck_tile::stream_config{nullptr, true, warmup, repeat, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{},
|
||||
kGridSize,
|
||||
blockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
m,
|
||||
n));
|
||||
|
||||
// Calculate and print performance metrics
|
||||
std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m * n;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// Copy results back to host
|
||||
y_buf.FromDevice(y_host_dev.mData.data());
|
||||
// Use exact equality (tolerance = 0) for copy operations since copy should be exact
|
||||
pass = ck_tile::check_err(y_host_dev, x_host, "Error: Copy operation failed!", 0.0, 0.0);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
// Print results for debugging
|
||||
// std::cout << "Input matrix (x_host):" << std::endl;
|
||||
// std::cout << x_host << std::endl;
|
||||
// std::cout << "Output matrix (y_host_dev):" << std::endl;
|
||||
// std::cout << y_host_dev << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
if(arg_parser.get_str("prec") == "fp16")
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
else
|
||||
return run<float>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
376
tutorial/ck_tile/00_copy_kernel/copy_basic.hpp
Normal file
376
tutorial/ck_tile/00_copy_kernel/copy_basic.hpp
Normal file
@@ -0,0 +1,376 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Tile copy shape configuration
|
||||
*
|
||||
* @tparam BlockWaves Number of waves along seq<M, N>
|
||||
* @tparam BlockTile Block size, seq<M, N>
|
||||
* @tparam WaveTile Wave size, seq<M, N>
|
||||
* @tparam ThreadTile Contiguous elements per thread along seq<M, N>
|
||||
*/
|
||||
template <typename BlockWaves, typename BlockTile, typename WaveTile, typename ThreadTile>
|
||||
struct TileCopyShape
|
||||
{
|
||||
// ThreadTile dimensions for memory operations
|
||||
static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
|
||||
|
||||
// Wave tile dimensions
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
static constexpr index_t Wave_Tile_N = WaveTile::at(number<1>{});
|
||||
static constexpr index_t Wave_Tile_M = ThreadTile_M * ThreadTile_N * WaveSize / Wave_Tile_N;
|
||||
|
||||
// Block tile dimensions
|
||||
static constexpr index_t Block_Tile_M = BlockTile::at(number<0>{});
|
||||
static constexpr index_t Block_Tile_N = BlockTile::at(number<1>{});
|
||||
|
||||
// Waves per block configuration
|
||||
static constexpr index_t Waves_Per_Block_M = BlockWaves::at(number<0>{});
|
||||
static constexpr index_t Waves_Per_Block_N = BlockWaves::at(number<1>{});
|
||||
|
||||
// Calculate wave repetition to cover entire block tile
|
||||
static constexpr index_t WaveRepetitionPerBlock_M =
|
||||
Block_Tile_M / (Waves_Per_Block_M * Wave_Tile_M);
|
||||
static constexpr index_t WaveRepetitionPerBlock_N =
|
||||
Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
|
||||
|
||||
// Hardware configuration
|
||||
static constexpr index_t BlockSize = Waves_Per_Block_M * Waves_Per_Block_N * WaveSize;
|
||||
|
||||
// Configuration validation
|
||||
static_assert(Block_Tile_M > 0 && Block_Tile_N > 0, "Block tile dimensions must be positive");
|
||||
static_assert(Wave_Tile_M > 0 && Wave_Tile_N > 0, "Wave tile dimensions must be positive");
|
||||
static_assert(ThreadTile_M > 0 && ThreadTile_N > 0, "ThreadTile dimensions must be positive");
|
||||
static_assert(Waves_Per_Block_M > 0 && Waves_Per_Block_N > 0,
|
||||
"Waves per block must be positive");
|
||||
static_assert(Waves_Per_Block_M * Wave_Tile_M > 0,
|
||||
"Invalid wave configuration for M dimension");
|
||||
static_assert(Waves_Per_Block_N * Wave_Tile_N > 0,
|
||||
"Invalid wave configuration for N dimension");
|
||||
|
||||
// Ensure wave tile dimensions align with wave size
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize,
|
||||
"(Wave_Tile_M/ThreadTile_M) * (Wave_Tile_N/ThreadTile_N) != WaveSize");
|
||||
#endif
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Problem definition for tile copy operation
|
||||
*/
|
||||
template <typename XDataType_, typename BlockShape_>
|
||||
struct TileCopyProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Policy for tile copy operation
|
||||
*/
|
||||
template <typename Problem_>
|
||||
struct TileCopyPolicy
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
|
||||
/**
|
||||
* @brief Create DRAM distribution for optimal memory access
|
||||
*/
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
constexpr index_t wave_size = S::WaveSize;
|
||||
constexpr index_t block_size = S::BlockSize;
|
||||
|
||||
// Distribution calculation to ensure all threads participate
|
||||
constexpr index_t N1 = S::ThreadTile_N; // Elements per thread along N
|
||||
constexpr index_t N0 = S::Block_Tile_N / N1; // Threads needed along N
|
||||
|
||||
constexpr index_t M2 = wave_size / N0; // Threads per wave along M
|
||||
constexpr index_t M1 = block_size / wave_size; // Waves possible along M
|
||||
constexpr index_t M0 = S::Block_Tile_M / (M1 * M2); // Wave iterations along M
|
||||
|
||||
// Validate complete coverage
|
||||
static_assert(M0 * M1 * M2 * N0 * N1 == S::Block_Tile_M * S::Block_Tile_N,
|
||||
"Tile distribution must cover entire block tile");
|
||||
|
||||
constexpr auto outer_encoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
return make_static_tile_distribution(outer_encoding);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Direct copy kernel from global memory to global memory
|
||||
*/
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct TileCopyKernel
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Calculate tile block origin and validate bounds
|
||||
// Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave
|
||||
// This saves VGPR usage by avoiding per-thread storage of the same value
|
||||
const auto tile_block_origin_m =
|
||||
__builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M);
|
||||
if(tile_block_origin_m >= M)
|
||||
{
|
||||
return; // Early exit for out-of-bounds blocks
|
||||
}
|
||||
|
||||
// Create tensor views for input and output
|
||||
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
|
||||
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
|
||||
|
||||
// Create tile windows with DRAM distribution
|
||||
auto x_window = make_tile_window(x_m_n,
|
||||
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
|
||||
{tile_block_origin_m, 0},
|
||||
Policy::template MakeDRAMDistribution<Problem>());
|
||||
|
||||
auto y_window = make_tile_window(y_m_n,
|
||||
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
|
||||
{tile_block_origin_m, 0},
|
||||
Policy::template MakeDRAMDistribution<Problem>());
|
||||
|
||||
// Calculate iterations needed to cover N dimension
|
||||
// Note: This kernel uses data parallelism only in the M dimension.
|
||||
// Each block processes one tile in M dimension, but iterates through N dimension tiles.
|
||||
// This design choice is for simplicity and to avoid complex tile distribution.
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N));
|
||||
|
||||
// Get tile distribution for register tensor
|
||||
auto DramTileDist = x_window.get_tile_distribution();
|
||||
using dram_reg_tile = decltype(make_static_distributed_tensor<XDataType>(DramTileDist));
|
||||
|
||||
// Main copy loop - processes N dimension tiles sequentially within each block
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
dram_reg_tile dram_tile;
|
||||
|
||||
// Direct copy implementation
|
||||
load_tile(dram_tile, x_window);
|
||||
store_tile(y_window, dram_tile);
|
||||
|
||||
// Move to next N tile
|
||||
move_tile_window(x_window, {0, S::Block_Tile_N});
|
||||
move_tile_window(y_window, {0, S::Block_Tile_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Element-wise copy kernel for data transformation scenarios
|
||||
*
|
||||
* This kernel performs element-wise copy operations, allowing for data transformation
|
||||
* during the copy process. Useful when data needs to be processed or converted
|
||||
* between different formats.
|
||||
*/
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct ElementWiseTileCopyKernel
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
CK_TILE_HOST static auto BlockSize()
|
||||
{
|
||||
if(ck_tile::is_wave32())
|
||||
{
|
||||
return kBlockSize / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return kBlockSize;
|
||||
}
|
||||
}
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Calculate block origin and validate bounds
|
||||
// Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave
|
||||
// This saves VGPR usage by avoiding per-thread storage of the same value
|
||||
const auto tile_block_origin_m =
|
||||
__builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M);
|
||||
if(tile_block_origin_m >= M)
|
||||
{
|
||||
return; // Early exit for out-of-bounds blocks
|
||||
}
|
||||
|
||||
// Create tensor views for input and output
|
||||
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
|
||||
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
|
||||
|
||||
// Create tile windows with DRAM distribution
|
||||
auto x_window = make_tile_window(x_m_n,
|
||||
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
|
||||
{tile_block_origin_m, 0},
|
||||
Policy::template MakeDRAMDistribution<Problem>());
|
||||
|
||||
auto y_window = make_tile_window(y_m_n,
|
||||
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
|
||||
{tile_block_origin_m, 0},
|
||||
Policy::template MakeDRAMDistribution<Problem>());
|
||||
|
||||
// Calculate iterations needed to cover N dimension
|
||||
// Note: This kernel uses data parallelism only in the M dimension.
|
||||
// Each block processes one tile in M dimension, but iterates through N dimension tiles.
|
||||
// This design choice is for simplicity and to avoid complex tile distribution.
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N));
|
||||
|
||||
// Main element-wise copy loop - processes N dimension tiles sequentially within each block
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
// Element-wise copy implementation for data transformation
|
||||
const auto xa = load_tile(x_window);
|
||||
auto y_compute = load_tile(y_window);
|
||||
|
||||
constexpr auto spans = decltype(xa)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
const auto x = ck_tile::type_convert<XDataType>(xa[i_j_idx]);
|
||||
y_compute(i_j_idx) = x;
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_window, y_compute);
|
||||
|
||||
// Move to next N tile
|
||||
move_tile_window(x_window, {0, S::Block_Tile_N});
|
||||
move_tile_window(y_window, {0, S::Block_Tile_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief LDS-based copy kernel for data processing scenarios
|
||||
*
|
||||
* This kernel copies data from global memory to LDS and then to global memory,
|
||||
* useful when data needs to be processed or transformed during the copy operation.
|
||||
*/
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct TileCopyKernel_LDS
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Calculate block origin and validate bounds
|
||||
// Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave
|
||||
// This saves VGPR usage by avoiding per-thread storage of the same value
|
||||
const auto tile_block_origin_m =
|
||||
__builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M);
|
||||
if(tile_block_origin_m >= M)
|
||||
{
|
||||
return; // Early exit for out-of-bounds blocks
|
||||
}
|
||||
|
||||
// LDS buffer allocation
|
||||
__shared__ XDataType x_lds_buffer[S::Block_Tile_Mmake * S::Block_Tile_N];
|
||||
|
||||
// LDS tensor descriptor and view
|
||||
const auto x_lds_descriptor =
|
||||
make_naive_tensor_descriptor(make_tuple(S::Block_Tile_M, S::Block_Tile_N),
|
||||
make_tuple(S::Block_Tile_N, 1),
|
||||
number<S::ThreadTile_N>{},
|
||||
number<1>{});
|
||||
|
||||
auto x_lds_view = make_tensor_view<address_space_enum::lds>(x_lds_buffer, x_lds_descriptor);
|
||||
|
||||
// LDS windows with different distributions for optimal access patterns
|
||||
auto x_lds_write_window =
|
||||
make_tile_window(x_lds_view, make_tuple(S::Block_Tile_M, S::Block_Tile_N), {0, 0});
|
||||
|
||||
auto x_lds_read_window = make_tile_window(x_lds_view,
|
||||
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
|
||||
{0, 0},
|
||||
Policy::template MakeDRAMDistribution<Problem>());
|
||||
|
||||
// Global memory tensor views
|
||||
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
|
||||
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::ThreadTile_N>{}, number<1>{});
|
||||
|
||||
// Global memory tile windows
|
||||
auto x_window = make_tile_window(x_m_n,
|
||||
make_tuple(S::Block_Tile_M, S::Block_Tile_N),
|
||||
{tile_block_origin_m, 0},
|
||||
Policy::template MakeDRAMDistribution<Problem>());
|
||||
|
||||
auto y_window = make_tile_window(
|
||||
y_m_n, make_tuple(S::Block_Tile_M, S::Block_Tile_N), {tile_block_origin_m, 0});
|
||||
|
||||
// Calculate iterations needed to cover N dimension
|
||||
// Note: This kernel uses data parallelism only in the M dimension.
|
||||
// Each block processes one tile in M dimension, but iterates through N dimension tiles.
|
||||
// This design choice is for simplicity and to avoid complex tile distribution.
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N));
|
||||
|
||||
// Main copy loop with LDS staging - processes N dimension tiles sequentially within each
|
||||
// block
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
// Global memory to LDS
|
||||
auto dram_tile = load_tile(x_window);
|
||||
store_tile(x_lds_write_window, dram_tile);
|
||||
|
||||
// Synchronize LDS access
|
||||
block_sync_lds();
|
||||
|
||||
// LDS to global memory
|
||||
auto lds_tile = load_tile(x_lds_read_window);
|
||||
store_tile(y_window, lds_tile);
|
||||
|
||||
// Move to next N tile
|
||||
move_tile_window(x_window, {0, S::Block_Tile_N});
|
||||
move_tile_window(y_window, {0, S::Block_Tile_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
53
tutorial/ck_tile/00_copy_kernel/test_tile_example.sh
Executable file
53
tutorial/ck_tile/00_copy_kernel/test_tile_example.sh
Executable file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
BIN="${BIN:-../../../build/bin/tile_tutorial_copy_kernel}"
|
||||
WARMUP="${WARMUP:-20}"
|
||||
REPEAT="${REPEAT:-100}"
|
||||
VALIDATE="${VALIDATE:-1}"
|
||||
|
||||
MS=(128 256 512 1024)
|
||||
NS=(64 256 1024 2048 4096)
|
||||
PRECS=(fp16 fp32)
|
||||
|
||||
echo "Using BIN=$BIN"
|
||||
echo "WARMUP=$WARMUP REPEAT=$REPEAT VALIDATE=$VALIDATE"
|
||||
|
||||
failures=0
|
||||
|
||||
for prec in "${PRECS[@]}"; do
|
||||
for m in "${MS[@]}"; do
|
||||
for n in "${NS[@]}"; do
|
||||
echo "=============================================="
|
||||
echo "Running: prec=$prec m=$m n=$n"
|
||||
set +e
|
||||
out="$("$BIN" -prec="$prec" -m="$m" -n="$n" -warmup="$WARMUP" -repeat="$REPEAT" -v="$VALIDATE" 2>&1)"
|
||||
rc=$?
|
||||
set -e
|
||||
|
||||
echo "$out"
|
||||
if [[ $rc -ne 0 ]]; then
|
||||
echo "RUN ERROR (rc=$rc) for m=$m n=$n prec=$prec"
|
||||
((failures++)) || true
|
||||
continue
|
||||
fi
|
||||
|
||||
if [[ "$VALIDATE" == "1" ]]; then
|
||||
if ! grep -q "valid:y" <<<"$out"; then
|
||||
echo "VALIDATION FAILED for m=$m n=$n prec=$prec"
|
||||
((failures++)) || true
|
||||
fi
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
echo "=============================================="
|
||||
if [[ $failures -eq 0 ]]; then
|
||||
echo "All runs passed"
|
||||
else
|
||||
echo "$failures runs failed"
|
||||
fi
|
||||
589
tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md
Normal file
589
tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md
Normal file
@@ -0,0 +1,589 @@
|
||||
# Block-Level Pipeline: PracticeGemmBlockPipelineAGmemBGmemCreg
|
||||
|
||||
## Overview
|
||||
|
||||
The **Block-Level Pipeline** is where the actual GEMM computation happens for one block tile. It orchestrates:
|
||||
1. **Data movement** from DRAM → Registers → LDS
|
||||
2. **GEMM computation** using data in LDS
|
||||
3. **Iteration** over the K dimension when needed
|
||||
|
||||
This pipeline is called by the host-level pipeline for each block tile that covers a portion of the output matrix C.
|
||||
|
||||
---
|
||||
|
||||
## Architecture: Problem and Policy
|
||||
|
||||
Like other components in CK Tile, the block pipeline follows the **Problem/Policy** pattern:
|
||||
|
||||
### Problem: `PracticeGemmBlockPipelineProblem`
|
||||
Contains:
|
||||
- **Data types**: `ADataType`, `BDataType`, `CDataType`, `AccDataType`
|
||||
- **Shape information**: `BlockTile` and `WaveTile` dimensions
|
||||
|
||||
### Policy: `PracticeGemmBlockPolicy`
|
||||
Contains strategies for:
|
||||
1. **Tile Distribution** (`MakeADramTileDistribution`, `MakeBDramTileDistribution`)
|
||||
- Defines how 256 threads in a block map to elements of a block tile
|
||||
- Each thread knows which elements to load/store from DRAM to its registers
|
||||
- We'll cover tile distribution construction in detail later
|
||||
|
||||
2. **LDS Layout** (`MakeALdsBlockDescriptor`, `MakeBLdsBlockDescriptor`)
|
||||
- Describes how data is logically organized in Local Data Share (LDS)
|
||||
- Optimizes for bank conflict avoidance and efficient access patterns
|
||||
- We'll cover LDS descriptor construction in detail later
|
||||
|
||||
3. **Warp Pipeline** (`GetPracticeWaveGemmPipeline`)
|
||||
- Returns the warp-level GEMM implementation
|
||||
|
||||
---
|
||||
|
||||
## Inputs and Outputs
|
||||
|
||||
```cpp
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
```
|
||||
|
||||
### Inputs:
|
||||
- `a_dram_block_window_tmp`: Tile window over A in DRAM (size: MPerBlock × KPerBlock)
|
||||
- `b_dram_block_window_tmp`: Tile window over B in DRAM (size: NPerBlock × KPerBlock)
|
||||
- `num_loop`: Number of iterations along K dimension
|
||||
- `p_smem`: Pointer to shared memory (LDS)
|
||||
|
||||
### Output:
|
||||
- `c_block_tile`: A `static_distributed_tensor` containing the computed C tile in registers (VGPRs)
|
||||
|
||||
---
|
||||
|
||||
## Step-by-Step Walkthrough
|
||||
|
||||
### Step 1: Create LDS Tensor Views
|
||||
|
||||
```cpp
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// B tile in LDS (placed after A in shared memory)
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We partition the shared memory (`p_smem`) into two regions: one for A, one for B
|
||||
- We create **tensor views** over these LDS regions using descriptors from the policy
|
||||
- `a_lds_block` and `b_lds_block` are logical views over raw LDS memory
|
||||
|
||||
**Memory Layout:**
|
||||
```
|
||||
Shared Memory (LDS):
|
||||
┌─────────────────────┬─────────────────────┐
|
||||
│ A Block Tile │ B Block Tile │
|
||||
│ (256×32 fp16) │ (128×32 fp16) │
|
||||
└─────────────────────┴─────────────────────┘
|
||||
↑ ↑
|
||||
p_a_lds p_b_lds
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 2: Create Tile Windows for Data Movement
|
||||
|
||||
We create **6 tile windows** for different purposes:
|
||||
|
||||
#### 2a. DRAM → Registers (Load from DRAM)
|
||||
|
||||
```cpp
|
||||
auto a_copy_dram_window = make_tile_window(
|
||||
a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>()); // ← Tile distribution!
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- `a_copy_dram_window` is a `tile_window_with_static_distribution`
|
||||
- The **tile distribution** tells each thread which elements to load from DRAM
|
||||
- This window will **slide along the K dimension** in the loop
|
||||
|
||||
#### 2b. Registers → LDS (Store to LDS)
|
||||
|
||||
```cpp
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // 256×32
|
||||
{0, 0}, // Origin at (0, 0) in LDS
|
||||
a_copy_dram_window.get_tile_distribution()); // ← Same distribution as DRAM!
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- Uses the **same tile distribution** as `a_copy_dram_window`
|
||||
- This ensures each thread stores to LDS in the same pattern it loaded from DRAM
|
||||
- Origin is always `{0, 0}` because LDS is reused for each K iteration
|
||||
|
||||
#### 2c. LDS → Registers (GEMM Input)
|
||||
|
||||
```cpp
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0}); // No tile distribution!
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- This is a `tile_window_with_static_lengths` (no explicit distribution)
|
||||
- Used as input to the warp-level GEMM
|
||||
- The warp GEMM will handle its own thread mapping internally
|
||||
|
||||
**Similar windows are created for B:**
|
||||
- `b_copy_dram_window`: Load B from DRAM
|
||||
- `b_copy_lds_window`: Store B to LDS
|
||||
- `b_lds_gemm_window`: Read B from LDS for GEMM
|
||||
|
||||
---
|
||||
|
||||
### Step 3: Create Distributed Tensors (VGPRs)
|
||||
|
||||
```cpp
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile; // Per-thread registers for A
|
||||
BBlockTile b_block_tile; // Per-thread registers for B
|
||||
```
|
||||
|
||||
#### What is `make_static_distributed_tensor`?
|
||||
|
||||
**`make_static_distributed_tensor`** creates a **`static_distributed_tensor`**, which is a compile-time abstraction for **distributed per-thread register storage**.
|
||||
|
||||
**Key Properties:**
|
||||
1. **Per-thread VGPRs**: Each thread owns a **different slice** of the tile in its registers
|
||||
2. **Compile-time sized**: Buffer size determined by tile distribution at compile time
|
||||
3. **Zero-overhead**: All indexing and layout transformations happen at compile time
|
||||
|
||||
**How it works:**
|
||||
|
||||
```cpp
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
// Calculate per-thread storage size from tile distribution
|
||||
using ThreadTensorDesc =
|
||||
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
|
||||
|
||||
static constexpr index_t kThreadElementSpaceSize =
|
||||
ThreadTensorDesc{}.get_element_space_size();
|
||||
|
||||
// Per-thread register array (VGPRs)
|
||||
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
|
||||
};
|
||||
```
|
||||
|
||||
**The tile distribution defines:**
|
||||
- **Which elements each thread owns** in the tile
|
||||
- **How many elements** each thread stores (buffer size)
|
||||
- **How elements are laid out** in each thread's registers
|
||||
|
||||
**Concrete Example for 256×32 tile with 256 threads:**
|
||||
|
||||
```
|
||||
Thread 0: a_block_tile.thread_buf_ = [A[0,0], A[0,1], ..., A[0,31]] (32 fp16 values)
|
||||
Thread 1: a_block_tile.thread_buf_ = [A[1,0], A[1,1], ..., A[1,31]] (32 fp16 values)
|
||||
Thread 2: a_block_tile.thread_buf_ = [A[2,0], A[2,1], ..., A[2,31]] (32 fp16 values)
|
||||
...
|
||||
Thread 255: a_block_tile.thread_buf_ = [A[255,0], A[255,1], ..., A[255,31]] (32 fp16 values)
|
||||
```
|
||||
|
||||
**Collectively:**
|
||||
- All 256 threads together hold the **entire 256×32 tile** (8192 elements)
|
||||
- Each thread's buffer lives in its **own VGPRs**
|
||||
- No two threads own the same element
|
||||
|
||||
**Distributed Ownership Analogy:**
|
||||
Think of a tile as a **jigsaw puzzle**:
|
||||
- The **tile distribution** is the cutting pattern
|
||||
- Each **thread** gets one puzzle piece (its slice)
|
||||
- Each **`static_distributed_tensor`** is a box holding all pieces
|
||||
- Each thread's **`thread_buf_`** is its individual piece in its own registers
|
||||
|
||||
---
|
||||
|
||||
### Step 4: The GEMM Loop
|
||||
|
||||
```cpp
|
||||
// Initialize C accumulator to zero
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
index_t iCounter = num_loop; // Number of K iterations
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// 1. Load from DRAM to registers
|
||||
a_block_tile = load_tile(a_copy_dram_window); // DRAM → VGPRs
|
||||
b_block_tile = load_tile(b_copy_dram_window); // DRAM → VGPRs
|
||||
|
||||
// 2. Move windows for next iteration
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step); // Step by (0, 32)
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // Step by (0, 32)
|
||||
|
||||
// 3. Store from registers to LDS
|
||||
store_tile(a_copy_lds_window, a_block_tile); // VGPRs → LDS
|
||||
store_tile(b_copy_lds_window, b_block_tile); // VGPRs → LDS
|
||||
|
||||
// 4. Synchronize threads (ensure all data is in LDS)
|
||||
block_sync_lds();
|
||||
|
||||
// 5. Compute GEMM using data in LDS
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// 6. Synchronize threads (before overwriting LDS in next iteration)
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
return c_block_tile; // Return accumulated result in registers
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Detailed Loop Breakdown
|
||||
|
||||
### Phase 1: Load (DRAM → VGPRs)
|
||||
|
||||
```cpp
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. Each thread reads **its assigned elements** from DRAM (determined by tile distribution)
|
||||
2. Data is loaded into **per-thread registers** (VGPRs)
|
||||
3. Uses **vectorized loads** for efficiency (e.g., loading 8 fp16 values at once)
|
||||
|
||||
**Example for Thread 0:**
|
||||
```
|
||||
Thread 0 loads:
|
||||
A[0,0:7] (8 fp16 values, one vector load)
|
||||
A[1,0:7] (8 fp16 values, one vector load)
|
||||
...
|
||||
```
|
||||
|
||||
### Phase 2: Move Windows
|
||||
|
||||
```cpp
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
- The tile window **slides along the K dimension** by `KPerBlock` (32 in our example)
|
||||
- This prepares for the next K iteration
|
||||
- The window origin moves from `(0, 0)` → `(0, 32)` → `(0, 64)` → ...
|
||||
|
||||
**Visualization for Problem Size 512×256×64:**
|
||||
```
|
||||
Matrix A (512×64):
|
||||
┌─────────────────────────────────────┐
|
||||
│ Block 0: rows 0-255 │
|
||||
│ ┌──────────┬──────────┐ │
|
||||
│ │ K=0:31 │ K=32:63 │ │ ← Window slides right
|
||||
│ │ Iter 0 │ Iter 1 │ │
|
||||
│ └──────────┴──────────┘ │
|
||||
└─────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Phase 3: Store (VGPRs → LDS)
|
||||
|
||||
```cpp
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. Each thread writes **its elements** from registers to LDS
|
||||
2. Uses the **same distribution** as the DRAM load
|
||||
3. Data is now in **shared memory**, accessible to all threads in the block
|
||||
|
||||
**Why this step?**
|
||||
- GEMM computation needs **all threads** to access **all data**
|
||||
- Registers are per-thread; LDS is shared across the block
|
||||
- LDS acts as a "staging area" for collaborative computation
|
||||
|
||||
### Phase 4: Synchronize
|
||||
|
||||
```cpp
|
||||
block_sync_lds();
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
- All threads in the block **wait** until everyone has finished storing to LDS
|
||||
- Ensures no thread starts reading from LDS before all writes are complete
|
||||
- Critical for correctness!
|
||||
|
||||
### Phase 5: GEMM Computation
|
||||
|
||||
```cpp
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. The warp-level GEMM reads data from LDS
|
||||
2. Performs matrix multiplication using MFMA instructions
|
||||
3. Accumulates results into `c_block_tile` (in registers)
|
||||
|
||||
**Note:** `c_block_tile` stays in registers throughout all K iterations, accumulating results.
|
||||
|
||||
### Phase 6: Synchronize Again
|
||||
|
||||
```cpp
|
||||
block_sync_lds();
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
- Ensures all threads have finished reading from LDS
|
||||
- Safe to overwrite LDS in the next iteration
|
||||
|
||||
---
|
||||
|
||||
## Memory Flow Diagram
|
||||
|
||||
```
|
||||
Iteration 0 (K=0:31):
|
||||
┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐
|
||||
│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │
|
||||
│ A[0:255,│ │ (per- │ │ A_block │
|
||||
│ 0:31] │ │ thread) │ │ │
|
||||
└─────────┘ └──────────┘ └─────────┘
|
||||
│
|
||||
│ block_gemm
|
||||
↓
|
||||
┌──────────┐
|
||||
│ c_block_ │
|
||||
│ tile │
|
||||
│ (VGPRs) │
|
||||
└──────────┘
|
||||
|
||||
Iteration 1 (K=32:63):
|
||||
┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐
|
||||
│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │
|
||||
│ A[0:255,│ │ (per- │ │ A_block │
|
||||
│ 32:63] │ │ thread) │ │ (reused)│
|
||||
└─────────┘ └──────────┘ └─────────┘
|
||||
│
|
||||
│ block_gemm
|
||||
↓
|
||||
┌──────────┐
|
||||
│ c_block_ │
|
||||
│ tile │
|
||||
│ (accum.) │
|
||||
└──────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Example: Problem Size 512×256×64
|
||||
|
||||
### Block 0 Computation
|
||||
|
||||
**Input:**
|
||||
- `a_dram_block_window_tmp`: Covers A[0:255, 0:31] initially
|
||||
- `b_dram_block_window_tmp`: Covers B[0:127, 0:31] initially (B is transposed)
|
||||
- `num_loop`: 2 (since K=64, KPerBlock=32)
|
||||
|
||||
**Iteration 0:**
|
||||
1. Load A[0:255, 0:31] and B[0:127, 0:31] from DRAM to VGPRs
|
||||
2. Move windows: A → [0:255, 32:63], B → [0:127, 32:63]
|
||||
3. Store to LDS
|
||||
4. Compute: `C[0:255, 0:127] += A[0:255, 0:31] × B[0:127, 0:31]^T`
|
||||
|
||||
**Iteration 1:**
|
||||
1. Load A[0:255, 32:63] and B[0:127, 32:63] from DRAM to VGPRs
|
||||
2. Move windows: A → [0:255, 64:95], B → [0:127, 64:95] (out of bounds, but loop ends)
|
||||
3. Store to LDS
|
||||
4. Compute: `C[0:255, 0:127] += A[0:255, 32:63] × B[0:127, 32:63]^T`
|
||||
|
||||
**Output:**
|
||||
- `c_block_tile`: Contains C[0:255, 0:127] in distributed registers
|
||||
|
||||
---
|
||||
|
||||
## Key Concepts Summary
|
||||
|
||||
### 1. Tile Distribution
|
||||
- **Maps threads to data elements** for load/store operations
|
||||
- Each thread knows exactly which elements it's responsible for
|
||||
- Enables **parallel, vectorized** memory access
|
||||
- **Same distribution** used for DRAM load and LDS store
|
||||
|
||||
### 2. Static Distributed Tensor
|
||||
- **Per-thread register storage** (VGPRs)
|
||||
- Each thread owns a **different slice** of the tile
|
||||
- **Compile-time sized** for zero-overhead abstraction
|
||||
- Used for: `a_block_tile`, `b_block_tile`, `c_block_tile`
|
||||
|
||||
### 3. Tile Window Movement
|
||||
- Windows **slide** over larger tensors
|
||||
- Enables iteration over the K dimension
|
||||
- `move_tile_window(window, step)` updates the origin
|
||||
|
||||
### 4. LDS as Staging Area
|
||||
- **Shared memory** accessible to all threads in a block
|
||||
- Required because GEMM needs all threads to access all data
|
||||
- **Reused** across K iterations (same LDS buffer)
|
||||
|
||||
### 5. Synchronization
|
||||
- `block_sync_lds()` ensures memory consistency
|
||||
- **Before GEMM**: All stores to LDS are complete
|
||||
- **After GEMM**: All reads from LDS are complete
|
||||
|
||||
---
|
||||
|
||||
## Deep Dive: `static_distributed_tensor` Mechanics
|
||||
|
||||
### How Tile Distribution Creates Per-Thread Storage
|
||||
|
||||
When you call:
|
||||
```cpp
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<fp16_t>(ABlockTileDistr{}));
|
||||
ABlockTile a_block_tile;
|
||||
```
|
||||
|
||||
**Step 1: Extract Thread Tensor Descriptor**
|
||||
|
||||
The tile distribution contains a `ys_to_d_descriptor` that maps:
|
||||
- **Y dimensions** (logical tile coordinates, e.g., M, K)
|
||||
- **D dimension** (per-thread register index, linearized)
|
||||
|
||||
```cpp
|
||||
using ThreadTensorDesc =
|
||||
decltype(StaticTileDistribution{}.get_ys_to_d_descriptor());
|
||||
```
|
||||
|
||||
**Step 2: Calculate Per-Thread Buffer Size**
|
||||
|
||||
```cpp
|
||||
static constexpr index_t kThreadElementSpaceSize =
|
||||
ThreadTensorDesc{}.get_element_space_size();
|
||||
|
||||
static constexpr index_t get_thread_buffer_size()
|
||||
{
|
||||
return kThreadElementSpaceSize / PackedSize;
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
- 256×32 tile distributed across 256 threads
|
||||
- Each thread owns 32 elements (one row)
|
||||
- `thread_buffer_size = 32` (for PackedSize=1)
|
||||
|
||||
**Step 3: Allocate Thread Buffer**
|
||||
|
||||
```cpp
|
||||
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
|
||||
```
|
||||
|
||||
This is essentially:
|
||||
```cpp
|
||||
fp16_t data[32]; // Per-thread register array (VGPRs)
|
||||
```
|
||||
|
||||
### Usage in Load/Store Operations
|
||||
|
||||
**Load from DRAM:**
|
||||
```cpp
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
```
|
||||
|
||||
What happens internally:
|
||||
1. Each thread queries the tile distribution: "Which elements do I own?"
|
||||
2. Thread 0 learns it owns A[0,0:31]
|
||||
3. Thread 0 loads those elements from DRAM into `a_block_tile.thread_buf_[0:31]`
|
||||
4. All 256 threads do this **in parallel**
|
||||
|
||||
**Store to LDS:**
|
||||
```cpp
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
```
|
||||
|
||||
What happens internally:
|
||||
1. Each thread reads from its `a_block_tile.thread_buf_`
|
||||
2. Thread 0 writes A[0,0:31] from its registers to LDS
|
||||
3. All 256 threads do this **in parallel**
|
||||
4. After `block_sync_lds()`, the entire tile is in shared LDS
|
||||
|
||||
### Distributed Indexing
|
||||
|
||||
The `static_distributed_tensor` supports compile-time indexing:
|
||||
|
||||
```cpp
|
||||
// Access using distributed indices
|
||||
auto value = a_block_tile(tile_distributed_index<i, j>{});
|
||||
```
|
||||
|
||||
Internally:
|
||||
1. Convert distributed index → Y index (logical tile coordinates)
|
||||
2. Calculate buffer offset using `ThreadTensorDesc`
|
||||
3. Access `thread_buf_[offset]`
|
||||
|
||||
All of this happens **at compile time** with zero runtime overhead!
|
||||
|
||||
### Why This Design?
|
||||
|
||||
**Benefits:**
|
||||
1. **Parallel Memory Access**: All threads load/store simultaneously
|
||||
2. **Vectorization**: Each thread can use vector loads (e.g., 8×fp16 at once)
|
||||
3. **Zero Overhead**: All indexing resolved at compile time
|
||||
4. **Type Safety**: Distribution mismatch caught at compile time
|
||||
5. **Register Pressure**: Compiler knows exact VGPR usage
|
||||
|
||||
**Trade-offs:**
|
||||
- Requires compile-time tile sizes
|
||||
- Distribution must be static
|
||||
- More complex type system
|
||||
|
||||
### Memory Hierarchy Summary
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ DRAM (Global Memory) │
|
||||
│ Full matrices A, B, C │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
│ load_tile (parallel, vectorized)
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ VGPRs (Per-Thread Registers) │
|
||||
│ Thread 0: a_block_tile.thread_buf_ = [A[0,0:31]] │
|
||||
│ Thread 1: a_block_tile.thread_buf_ = [A[1,0:31]] │
|
||||
│ ... │
|
||||
│ Thread 255: a_block_tile.thread_buf_ = [A[255,0:31]] │
|
||||
│ │
|
||||
│ ← static_distributed_tensor manages this distribution │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
│ store_tile (parallel, vectorized)
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ LDS (Shared Memory) │
|
||||
│ Entire block tile (256×32) │
|
||||
│ Accessible to all threads in block │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key Insight:**
|
||||
`static_distributed_tensor` is the abstraction that enables efficient, parallel data movement between DRAM and LDS through per-thread VGPRs, with all coordination happening at compile time.
|
||||
|
||||
|
||||
|
||||
7
tutorial/ck_tile/01_naive_gemm/CMakeLists.txt
Normal file
7
tutorial/ck_tile/01_naive_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
add_executable(tile_tutorial_naive_gemm EXCLUDE_FROM_ALL practice_gemm.cpp)
|
||||
|
||||
target_compile_options(tile_tutorial_naive_gemm PRIVATE
|
||||
-mllvm -enable-noalias-to-md-conversion=0
|
||||
)
|
||||
|
||||
add_dependencies(tutorials tile_tutorial_naive_gemm)
|
||||
618
tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md
Normal file
618
tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md
Normal file
@@ -0,0 +1,618 @@
|
||||
# Host-Level Pipeline: Orchestrating Block-Level GEMM
|
||||
|
||||
This document explains the **host-level pipeline** (`PracticeGemmHostPipeline`), which orchestrates the distribution of work across thread blocks and manages the high-level flow of the GEMM computation.
|
||||
|
||||
## Overview
|
||||
|
||||
The host-level pipeline is responsible for:
|
||||
1. **Calculating tile coverage**: How many tiles are needed to cover matrices A, B, and C
|
||||
2. **Block-to-tile mapping**: Assigning each thread block to a specific tile
|
||||
3. **Creating tile windows**: Establishing sliding windows over tensor views
|
||||
4. **Delegating computation**: Calling the block-level pipeline to perform actual GEMM
|
||||
5. **Storing results**: Writing computed tiles from registers (VGPRs) back to DRAM
|
||||
|
||||
```cpp
|
||||
template <typename Problem_, typename Policy_ = PracticeGemmHostPolicy>
|
||||
struct PracticeGemmHostPipeline
|
||||
{
|
||||
template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
|
||||
CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
|
||||
const BDRAMTensorView& b_dram,
|
||||
CDRAMTensorView& c_dram) const
|
||||
{
|
||||
// 1. Calculate problem dimensions and tile coverage
|
||||
// 2. Map thread block to tile coordinates
|
||||
// 3. Create tile windows over A and B
|
||||
// 4. Call block-level pipeline to compute
|
||||
// 5. Store result to C
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 1: Calculate Problem Dimensions and Tile Coverage
|
||||
|
||||
```cpp
|
||||
// Size of the entire problem
|
||||
const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K
|
||||
const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N
|
||||
const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K
|
||||
|
||||
// Size of the block tile
|
||||
const auto MPerBlock = BlockTile::at(number<0>{}); // 256
|
||||
const auto NPerBlock = BlockTile::at(number<1>{}); // 128
|
||||
const auto KPerBlock = BlockTile::at(number<2>{}); // 32
|
||||
|
||||
// Number of block tiles needed to cover C matrix
|
||||
const auto num_tile_n = integer_divide_ceil(N, NPerBlock); // ceil(256/128) = 2
|
||||
const auto num_tile_m = integer_divide_ceil(M, MPerBlock); // ceil(512/256) = 2
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
1. **Extract problem dimensions** from tensor descriptors:
|
||||
- `M = 512`: Rows in A and C
|
||||
- `N = 256`: Columns in B and C
|
||||
- `K = 64`: Inner dimension (columns of A, rows of B)
|
||||
|
||||
2. **Get block tile sizes** from the `BlockTile` configuration:
|
||||
- `MPerBlock = 256`: Each block processes 256 rows
|
||||
- `NPerBlock = 128`: Each block processes 128 columns
|
||||
- `KPerBlock = 32`: Each block processes 32 elements in K dimension per iteration
|
||||
|
||||
3. **Calculate tile coverage**:
|
||||
- `num_tile_m = ceil(M / MPerBlock) = ceil(512/256) = 2` tiles in M direction
|
||||
- `num_tile_n = ceil(N / NPerBlock) = ceil(256/128) = 2` tiles in N direction
|
||||
- **Total tiles = 2 × 2 = 4 tiles** → We need **4 thread blocks**!
|
||||
|
||||
### Visual Representation:
|
||||
|
||||
```
|
||||
Matrix C (512 × 256):
|
||||
┌──────────────────────┬──────────────────────┐
|
||||
│ Tile (0,0) │ Tile (0,1) │ ← num_tile_n = 2
|
||||
│ 256×128 │ 256×128 │
|
||||
│ Block 0 │ Block 1 │
|
||||
│ │ │
|
||||
├──────────────────────┼──────────────────────┤
|
||||
│ Tile (1,0) │ Tile (1,1) │
|
||||
│ 256×128 │ 256×128 │
|
||||
│ Block 2 │ Block 3 │
|
||||
│ │ │
|
||||
└──────────────────────┴──────────────────────┘
|
||||
↑
|
||||
num_tile_m = 2
|
||||
|
||||
Total blocks needed = 2 × 2 = 4 blocks
|
||||
|
||||
Each block computes one 256×128 tile of the output matrix C.
|
||||
```
|
||||
|
||||
### How Blocks Cover Matrices A and B:
|
||||
|
||||
```
|
||||
Matrix A (512 × 64): Matrix B (256 × 64):
|
||||
┌─────────────┬──────┐ ┌─────────────┬──────┐
|
||||
│ Block 0,2 │ K │ │ Block 0,1 │ K │
|
||||
│ uses rows │ → │ │ uses rows │ → │
|
||||
│ 0-255 │ │ │ 0-127 │ │
|
||||
├─────────────┼──────┤ ├─────────────┼──────┤
|
||||
│ Block 1,3 │ K │ │ Block 2,3 │ K │
|
||||
│ uses rows │ → │ │ uses rows │ → │
|
||||
│ 256-511 │ │ │ 128-255 │ │
|
||||
└─────────────┴──────┘ └─────────────┴──────┘
|
||||
256 rows 64 cols 128 rows 64 cols
|
||||
|
||||
Each block needs to iterate over K dimension (64/32 = 2 iterations)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 2: Map Thread Block to Tile Coordinates
|
||||
|
||||
```cpp
|
||||
// Get block id (0 to total_blocks - 1)
|
||||
const auto id_block = get_block_id();
|
||||
|
||||
// Map block id to 2D tile coordinates
|
||||
const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n);
|
||||
const auto tile_id = block2tile(id_block);
|
||||
|
||||
const auto tile_id_m = tile_id.at(number<0>{}); // M coordinate
|
||||
const auto tile_id_n = tile_id.at(number<1>{}); // N coordinate
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
Each thread block needs to know **which tile of the output matrix C it should compute**. The `MakeBlock2TileMap` function creates a mapping from linear block ID to 2D tile coordinates.
|
||||
|
||||
### The `MakeBlock2TileMap` Function:
|
||||
|
||||
```cpp
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
// Create a merge transform: (N0, M0) → linear index
|
||||
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
|
||||
|
||||
return [unmerge](index_t block_id) {
|
||||
multi_index<2> unmerged;
|
||||
// Convert linear block_id back to 2D coordinates
|
||||
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
|
||||
|
||||
// Return (m_idx, n_idx) - note the swap!
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
### In Our Example (2×2 Grid):
|
||||
|
||||
```cpp
|
||||
// Block 0:
|
||||
id_block = 0
|
||||
tile_id = block2tile(0) = (0, 0) // Top-left tile
|
||||
tile_id_m = 0, tile_id_n = 0
|
||||
|
||||
// Block 1:
|
||||
id_block = 1
|
||||
tile_id = block2tile(1) = (1, 0) // Bottom-left tile
|
||||
tile_id_m = 1, tile_id_n = 0
|
||||
|
||||
// Block 2:
|
||||
id_block = 2
|
||||
tile_id = block2tile(2) = (0, 1) // Top-right tile
|
||||
tile_id_m = 0, tile_id_n = 1
|
||||
|
||||
// Block 3:
|
||||
id_block = 3
|
||||
tile_id = block2tile(3) = (1, 1) // Bottom-right tile
|
||||
tile_id_m = 1, tile_id_n = 1
|
||||
```
|
||||
|
||||
**Key Point**: Each of the 4 blocks knows exactly which 256×128 tile of C it's responsible for computing!
|
||||
|
||||
---
|
||||
|
||||
## Step 3: Calculate Tile Origin and Create Tile Windows
|
||||
|
||||
```cpp
|
||||
// Calculate the starting position of this tile in the global matrix
|
||||
const auto tile_origin_m = tile_id_m * MPerBlock; // e.g., Block 1: 1 * 256 = 256
|
||||
const auto tile_origin_n = tile_id_n * NPerBlock; // e.g., Block 2: 1 * 128 = 128
|
||||
|
||||
// Create tile windows over A and B tensor views
|
||||
const auto a_block_window = make_tile_window(
|
||||
a_dram, // Tensor view over A
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // Window size: 256×32
|
||||
{tile_origin_m, 0} // Origin: varies by block
|
||||
);
|
||||
|
||||
const auto b_block_window = make_tile_window(
|
||||
b_dram, // Tensor view over B
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), // Window size: 128×32
|
||||
{tile_origin_n, 0} // Origin: varies by block
|
||||
);
|
||||
```
|
||||
|
||||
### Tile Origins for Each Block:
|
||||
|
||||
```cpp
|
||||
// Block 0 (Tile 0,0):
|
||||
tile_origin_m = 0 * 256 = 0
|
||||
tile_origin_n = 0 * 128 = 0
|
||||
a_block_window origin: (0, 0) → covers A rows 0-255
|
||||
b_block_window origin: (0, 0) → covers B rows 0-127
|
||||
|
||||
// Block 1 (Tile 1,0):
|
||||
tile_origin_m = 1 * 256 = 256
|
||||
tile_origin_n = 0 * 128 = 0
|
||||
a_block_window origin: (256, 0) → covers A rows 256-511
|
||||
b_block_window origin: (0, 0) → covers B rows 0-127
|
||||
|
||||
// Block 2 (Tile 0,1):
|
||||
tile_origin_m = 0 * 256 = 0
|
||||
tile_origin_n = 1 * 128 = 128
|
||||
a_block_window origin: (0, 0) → covers A rows 0-255
|
||||
b_block_window origin: (128, 0) → covers B rows 128-255
|
||||
|
||||
// Block 3 (Tile 1,1):
|
||||
tile_origin_m = 1 * 256 = 256
|
||||
tile_origin_n = 1 * 128 = 128
|
||||
a_block_window origin: (256, 0) → covers A rows 256-511
|
||||
b_block_window origin: (128, 0) → covers B rows 128-255
|
||||
```
|
||||
|
||||
### What are Tile Windows?
|
||||
|
||||
A **tile window** is a **sliding window** over a larger tensor view. It:
|
||||
- Defines a **rectangular region** within the tensor
|
||||
- Has a **fixed size** (e.g., 256×32 for A)
|
||||
- Has an **origin** (starting position)
|
||||
- Can be **moved** to access different regions
|
||||
### Visual Representation (Block 0 Example):
|
||||
|
||||
```
|
||||
Matrix A (512 × 64): Matrix B (256 × 64):
|
||||
┌─────────────┬─────────────┐ ┌─────────────┬─────────────┐
|
||||
│ ┏━━━━━━━━━┓ │ │ │ ┏━━━━━━━━━┓ │ │
|
||||
│ ┃ Window ┃ │ │ │ ┃ Window ┃ │ │
|
||||
│ ┃ 256×32 ┃ │ │ │ ┃ 128×32 ┃ │ │
|
||||
│ ┃ K=0-31 ┃ │ │ │ ┃ K=0-31 ┃ │ │
|
||||
│ ┗━━━━━━━━━┛ │ │ │ ┗━━━━━━━━━┛ │ │
|
||||
│ │ │ ├─────────────┼─────────────┤
|
||||
├─────────────┼─────────────┤ │ │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ │ │ │ │
|
||||
└─────────────┴─────────────┘ └─────────────┴─────────────┘
|
||||
Origin: (0, 0) Origin: (0, 0)
|
||||
Covers rows 0-255 Covers rows 0-127
|
||||
Covers cols 0-31 (first K iteration) Covers cols 0-31 (first K iteration)
|
||||
```
|
||||
|
||||
**Note**: The window initially covers K columns 0-31. It will move to cover K columns 32-63 in the next iteration.
|
||||
|
||||
### Tile Window Properties:
|
||||
|
||||
```cpp
|
||||
// Tile window structure (conceptual):
|
||||
struct tile_window {
|
||||
TensorView& tensor_view; // Reference to underlying tensor
|
||||
Tuple window_lengths; // Size of the window (256, 32)
|
||||
MultiIndex window_origin; // Starting position (0, 0)
|
||||
|
||||
// Can move the window:
|
||||
void move(MultiIndex step); // Shift window by step
|
||||
|
||||
// Access data through the window:
|
||||
auto load(); // Load data from windowed region
|
||||
};
|
||||
```
|
||||
|
||||
|
||||
### Tile Window Movement: Iterating Over K Dimension
|
||||
|
||||
In our example, **K=64** but **KPerBlock=32**, so we need **2 iterations** over the K dimension:
|
||||
|
||||
```
|
||||
Matrix A (512 × 64) - Block 0's view:
|
||||
┌─────────────┬─────────────┐
|
||||
│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │
|
||||
│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ ← Window slides along K
|
||||
│ ┃ 256×32 ┃ │ ║ 256×32 ║ │
|
||||
│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │
|
||||
│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │
|
||||
├─────────────┼─────────────┤
|
||||
│ │ │
|
||||
│ Block 1's │ │
|
||||
│ region │ │
|
||||
└─────────────┴─────────────┘
|
||||
|
||||
Matrix B (256 × 64) - Block 0's view:
|
||||
┌─────────────┬─────────────┐
|
||||
│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │
|
||||
│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │
|
||||
│ ┃ 128×32 ┃ │ ║ 128×32 ║ │
|
||||
│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │
|
||||
│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │
|
||||
├─────────────┼─────────────┤
|
||||
│ Block 2's │ │
|
||||
│ region │ │
|
||||
└─────────────┴─────────────┘
|
||||
```
|
||||
|
||||
### How Windows Move (Conceptual - handled by block pipeline):
|
||||
|
||||
```cpp
|
||||
// Iteration 0:
|
||||
a_block_window origin: (tile_origin_m, 0) // K columns 0-31
|
||||
b_block_window origin: (tile_origin_n, 0) // K columns 0-31
|
||||
// Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31]
|
||||
|
||||
// Move windows to next K position:
|
||||
move_tile_window(a_block_window, {0, 32});
|
||||
move_tile_window(b_block_window, {0, 32});
|
||||
|
||||
// Iteration 1:
|
||||
a_block_window origin: (tile_origin_m, 32) // K columns 32-63
|
||||
b_block_window origin: (tile_origin_n, 32) // K columns 32-63
|
||||
// Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63]
|
||||
|
||||
// Final result:
|
||||
// C_tile = C_partial_0 + C_partial_1
|
||||
```
|
||||
|
||||
**Key Insight**: The tile windows **slide along the K dimension** to cover the full inner product. Each block accumulates partial results across K iterations to compute its final tile of C.
|
||||
|
||||
---
|
||||
|
||||
## Step 4: Delegate to Block-Level Pipeline
|
||||
|
||||
```cpp
|
||||
// Get the block-level pipeline from policy
|
||||
constexpr auto block_gemm_pipeline =
|
||||
Policy::template GetPracticeGemmBlockPipeline<Problem>();
|
||||
|
||||
// Calculate number of K iterations needed
|
||||
int num_loops_k = integer_divide_ceil(K, KPerBlock); // ceil(64/32) = 2
|
||||
|
||||
// Allocate shared memory (LDS) for block-level computation
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()];
|
||||
|
||||
// Call block-level pipeline to compute C tile
|
||||
const auto c_block_tile =
|
||||
block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char);
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
1. **Retrieve block pipeline**: The policy provides the block-level GEMM implementation
|
||||
2. **Calculate K iterations**: How many times to iterate over the K dimension
|
||||
- In our example: `K=64, KPerBlock=32` → **2 iterations**
|
||||
- Each iteration processes 32 elements of the K dimension
|
||||
- Results are accumulated across iterations
|
||||
|
||||
3. **Allocate shared memory**:
|
||||
- `__shared__` declares memory shared by all threads in the block
|
||||
- `GetStaticLDSSize()` returns the required size in bytes
|
||||
- This memory is used for:
|
||||
- Staging data from DRAM → LDS
|
||||
- Cooperative loading by threads
|
||||
- Fast access during computation
|
||||
|
||||
4. **Execute block pipeline**:
|
||||
- Takes A and B tile windows as input
|
||||
- Performs the GEMM computation: `C_tile = A_tile × B_tile`
|
||||
- Returns result in `c_block_tile` (stored in VGPRs - registers)
|
||||
|
||||
### Memory Hierarchy During Computation:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ DRAM (Global Memory) - Slowest, Largest │
|
||||
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
|
||||
│ │ A matrix │ │ B matrix │ │ C matrix │ │
|
||||
│ └─────────────┘ └─────────────┘ └─────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
↓ load ↓ load ↑ store
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ LDS (Shared Memory) - Fast, Limited Size (~64KB) │
|
||||
│ ┌─────────────┐ ┌─────────────┐ │
|
||||
│ │ A_tile │ │ B_tile │ ← Staged here │
|
||||
│ │ (p_smem) │ │ (p_smem) │ │
|
||||
│ └─────────────┘ └─────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
↓ load ↓ load
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ VGPRs (Registers) - Fastest, Smallest (~256 regs/thread) │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ c_block_tile (accumulated result) │ │
|
||||
│ │ Computation happens here using MFMA instructions │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Block Pipeline Responsibilities:
|
||||
|
||||
The block pipeline (called here) will:
|
||||
1. Load A and B tiles from DRAM → LDS (cooperative loading)
|
||||
2. Distribute work among warps
|
||||
3. Each warp loads its portion from LDS → VGPRs
|
||||
4. Perform MFMA operations: `C += A × B`
|
||||
5. Accumulate results in VGPRs
|
||||
6. Return final `c_block_tile` in registers
|
||||
|
||||
---
|
||||
|
||||
## Step 5: Store Results to DRAM
|
||||
|
||||
```cpp
|
||||
// Create a tile window over C for writing results
|
||||
auto c_window = make_tile_window(
|
||||
c_dram, // Tensor view over C
|
||||
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}), // Window size: 256×128
|
||||
{tile_origin_m, tile_origin_n} // Origin: varies by block
|
||||
);
|
||||
|
||||
// Store computed tile from VGPRs to DRAM
|
||||
store_tile(c_window, c_block_tile);
|
||||
```
|
||||
|
||||
### C Window Origins for Each Block:
|
||||
|
||||
```cpp
|
||||
// Block 0: Writes to top-left tile
|
||||
c_window origin: (0, 0) → writes to C[0:255, 0:127]
|
||||
|
||||
// Block 1: Writes to bottom-left tile
|
||||
c_window origin: (256, 0) → writes to C[256:511, 0:127]
|
||||
|
||||
// Block 2: Writes to top-right tile
|
||||
c_window origin: (0, 128) → writes to C[0:255, 128:255]
|
||||
|
||||
// Block 3: Writes to bottom-right tile
|
||||
c_window origin: (256, 128) → writes to C[256:511, 128:255]
|
||||
```
|
||||
|
||||
### What's Happening:
|
||||
|
||||
1. **Create C tile window**:
|
||||
- Size: 256×128 (matches our block tile size)
|
||||
- Origin: Varies by block - each block writes to its assigned region
|
||||
- This window defines **where** to write the results
|
||||
|
||||
2. **Store tile to DRAM**:
|
||||
- `c_block_tile`: Computed results in VGPRs (registers)
|
||||
- `c_window`: Destination window in DRAM
|
||||
- `store_tile()`: Efficiently writes data from registers → DRAM
|
||||
|
||||
### The `store_tile` Function:
|
||||
|
||||
Recall from our earlier discussion, `store_tile` does:
|
||||
|
||||
```cpp
|
||||
template <typename TileWindow, typename DistributedTensor>
|
||||
void store_tile(TileWindow& tile_window_tmp,
|
||||
const DistributedTensor& dstr_tensor)
|
||||
{
|
||||
// 1. Extract tile distribution from distributed tensor
|
||||
using TileDstr = typename DistributedTensor::TileDistribution;
|
||||
|
||||
// 2. Upgrade simple tile window to one with distribution
|
||||
auto tile_window = make_tile_window(
|
||||
tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
TileDstr{} // Add distribution info
|
||||
);
|
||||
|
||||
// 3. Store using vectorized writes
|
||||
tile_window.store(dstr_tensor);
|
||||
}
|
||||
```
|
||||
|
||||
### Memory Flow:
|
||||
|
||||
```
|
||||
VGPRs (Registers) DRAM (Global Memory)
|
||||
┌─────────────────────┐ ┌─────────────────────┐
|
||||
│ c_block_tile │ │ C matrix │
|
||||
│ ┌───┬───┬───┬───┐ │ │ ┌───────────────┐ │
|
||||
│ │W0 │W1 │W2 │W3 │ │ store_tile │ │ │ │
|
||||
│ ├───┼───┼───┼───┤ │ ==========> │ │ c_window │ │
|
||||
│ │...│...│...│...│ │ vectorized │ │ (256×128) │ │
|
||||
│ └───┴───┴───┴───┘ │ │ │ │ │
|
||||
│ Distributed across │ │ └───────────────┘ │
|
||||
│ threads/warps │ │ Origin: (0, 0) │
|
||||
└─────────────────────┘ └─────────────────────┘
|
||||
|
||||
Each thread writes its portion using vector stores (e.g., float4)
|
||||
```
|
||||
|
||||
### Store Optimization:
|
||||
|
||||
The `store_tile` function:
|
||||
- Uses **vectorized stores** (write multiple elements at once)
|
||||
- Ensures **coalesced memory access** (adjacent threads write adjacent memory)
|
||||
- Respects **tile distribution** (each thread knows what data it owns)
|
||||
- Handles **out-of-bounds** checking (for partial tiles at boundaries)
|
||||
|
||||
---
|
||||
|
||||
## Complete Flow Visualization
|
||||
|
||||
Let's trace the complete flow for **Block 0** (other blocks follow the same pattern):
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 1: Calculate Tile Coverage │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ M=512, N=256, K=64 │ │
|
||||
│ │ MPerBlock=256, NPerBlock=128, KPerBlock=32 │ │
|
||||
│ │ num_tile_m = ceil(512/256) = 2 │ │
|
||||
│ │ num_tile_n = ceil(256/128) = 2 │ │
|
||||
│ │ Total blocks needed = 2 × 2 = 4 blocks │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 2: Map Block to Tile (Block 0 example) │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Block ID: 0 │ │
|
||||
│ │ Tile coordinates: (0, 0) - top-left tile │ │
|
||||
│ │ Tile origin: (0, 0) │ │
|
||||
│ │ │ │
|
||||
│ │ (Blocks 1,2,3 get different tile coordinates) │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 3: Create Tile Windows │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ a_block_window: 256×32 starting at (0,0) over A │ │
|
||||
│ │ b_block_window: 128×32 starting at (0,0) over B │ │
|
||||
│ │ Windows initially cover K columns 0-31 │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 4: Execute Block Pipeline (2 K iterations) │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Allocate shared memory (LDS) │ │
|
||||
│ │ Call block_gemm_pipeline(a_window, b_window, 2, p_smem) │ │
|
||||
│ │ │ │
|
||||
│ │ K Iteration 0 (K=0-31): │ │
|
||||
│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] │ │
|
||||
│ │ └─ Move windows: {0, 32} │ │
|
||||
│ │ │ │
|
||||
│ │ K Iteration 1 (K=32-63): │ │
|
||||
│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │
|
||||
│ │ ├─ Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] │ │
|
||||
│ │ └─ Accumulate: C_tile = C_partial_0 + C_partial_1 │ │
|
||||
│ │ │ │
|
||||
│ │ Return c_block_tile in VGPRs (256×128 accumulated result) │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 5: Store Results │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Create c_window: 256×128 starting at (0,0) over C │ │
|
||||
│ │ store_tile(c_window, c_block_tile) │ │
|
||||
│ │ └─ Write from VGPRs → DRAM (vectorized stores) │ │
|
||||
│ │ │ │
|
||||
│ │ Block 0 writes to C[0:255, 0:127] │ │
|
||||
│ │ (Other blocks write to their respective regions) │ │
|
||||
│ └─────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
|
||||
All 4 blocks execute in parallel, each computing its assigned 256×128 tile!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Concepts Summary
|
||||
|
||||
### 1. **Tile Coverage**
|
||||
- Determines how many thread blocks are needed
|
||||
- Each block processes one tile of the output matrix C
|
||||
- Calculated as `ceil(dimension / tile_size)`
|
||||
|
||||
### 2. **Block-to-Tile Mapping**
|
||||
- Maps linear block ID to 2D tile coordinates
|
||||
- Uses column-major ordering for better memory coalescing
|
||||
- Each block knows which tile it's responsible for
|
||||
|
||||
### 3. **Tile Windows**
|
||||
- **Sliding windows** over larger tensor views
|
||||
- Define a rectangular region with fixed size and movable origin
|
||||
- Provide efficient, structured access to tensor data
|
||||
- Can be moved to access different regions (e.g., for K iterations)
|
||||
|
||||
### 4. **Memory Hierarchy**
|
||||
- **DRAM (Global)**: Largest, slowest - stores full matrices
|
||||
- **LDS (Shared)**: Medium, fast - stages tiles for cooperative access
|
||||
- **VGPRs (Registers)**: Smallest, fastest - performs computation
|
||||
|
||||
### 5. **Data Flow**
|
||||
```
|
||||
DRAM → Tile Windows → LDS → VGPRs → Computation → VGPRs → DRAM
|
||||
↑ ↓
|
||||
A, B matrices C matrix
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Next Steps
|
||||
|
||||
The host-level pipeline has set up the work and delegated to the block-level pipeline. Next, we'll explore:
|
||||
- **Block-level pipeline**: How tiles are loaded, distributed to warps, and computed
|
||||
- **Warp-level pipeline**: How warps perform MFMA operations
|
||||
- **Memory optimization**: LDS usage, bank conflicts, coalescing
|
||||
|
||||
The host level provides the **orchestration**, while the block and warp levels provide the **execution**!
|
||||
|
||||
464
tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md
Normal file
464
tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md
Normal file
@@ -0,0 +1,464 @@
|
||||
# PracticeGemmKernel: Understanding the Kernel Entry Point
|
||||
|
||||
This document explains the `PracticeGemmKernel` structure, which serves as the **entry point** for our GEMM GPU kernel. We'll dive deep into how raw memory is transformed into structured tensor views.
|
||||
|
||||
## Overview
|
||||
|
||||
The `PracticeGemmKernel` is a templated struct that:
|
||||
1. Takes raw device memory pointers for matrices A, B, and C
|
||||
2. Wraps them into **tensor views** - logical, structured views over physical memory
|
||||
3. Dispatches to the host-level pipeline for computation
|
||||
|
||||
```cpp
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct PracticeGemmKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
|
||||
const typename Problem::BDataType* p_b,
|
||||
typename Problem::CDataType* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t stride_a,
|
||||
const index_t stride_b,
|
||||
const index_t stride_c) const
|
||||
{
|
||||
// Step 1: Create tensor views over raw memory
|
||||
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{});
|
||||
|
||||
auto b_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{});
|
||||
|
||||
const auto c_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{});
|
||||
|
||||
// Step 2: Dispatch to host-level pipeline
|
||||
PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## What are Tensor Views?
|
||||
|
||||
A **tensor view** is a **logical, structured view over raw physical memory**. It doesn't own or allocate memory—it simply provides a way to interpret and access existing memory as a multi-dimensional tensor.
|
||||
|
||||
### Key Components of a Tensor View:
|
||||
|
||||
1. **Memory Type**: Where the data lives (global/DRAM, LDS/shared, registers)
|
||||
2. **Raw Pointer**: Points to the actual data in memory
|
||||
3. **Shape**: Dimensions of the tensor (e.g., M×K for matrix A)
|
||||
4. **Strides**: How to navigate through memory to access elements
|
||||
5. **Guaranteed Vector Length**: How many consecutive elements can be loaded in one vector instruction
|
||||
6. **Guaranteed Vector Stride**: The stride of those vectorizable elements
|
||||
|
||||
---
|
||||
|
||||
## The Memory Abstraction Hierarchy
|
||||
|
||||
CK Tile uses a three-layer abstraction to go from raw memory to structured tensors:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Layer 3: TENSOR VIEW │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ • Logical multi-dimensional structure │ │
|
||||
│ │ • Shape: (M, K) = (256, 32) │ │
|
||||
│ │ • Strides: (32, 1) for row-major layout │ │
|
||||
│ │ • Provides: operator[], coordinate-based access │ │
|
||||
│ │ • Knows: How to map (i,j) → linear offset │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
│ ↓ wraps │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ Layer 2: BUFFER VIEW │ │
|
||||
│ │ ┌─────────────────────────────────────────────────────┐ │ │
|
||||
│ │ │ • Linear view of memory │ │ │
|
||||
│ │ │ • Pointer: p_data_ → device memory │ │ │
|
||||
│ │ │ • Size: Total number of elements │ │ │
|
||||
│ │ │ • Address space: global/LDS/generic │ │ │
|
||||
│ │ │ • Provides: Vectorized loads/stores, bounds checking│ │ │
|
||||
│ │ └─────────────────────────────────────────────────────┘ │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
│ ↓ wraps │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ Layer 1: RAW PHYSICAL MEMORY │ │
|
||||
│ │ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ │
|
||||
│ │ │ 0.0 │ 1.0 │ 2.0 │ 3.0 │ 4.0 │ 5.0 │ 6.0 │ 7.0 │ ... │ │ │
|
||||
│ │ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ │ │
|
||||
│ │ ↑ │ │
|
||||
│ │ p_a (raw pointer from hipMalloc) │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Deep Dive: `make_naive_tensor_view`
|
||||
|
||||
Let's break down the function call for matrix A:
|
||||
|
||||
```cpp
|
||||
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, // Raw pointer to device memory
|
||||
make_tuple(M, K), // Shape: (256, 32)
|
||||
make_tuple(stride_a, 1), // Strides: (32, 1) - row-major
|
||||
number<8>{}, // Guaranteed vector length
|
||||
number<1>{} // Guaranteed vector stride
|
||||
);
|
||||
```
|
||||
|
||||
### Function Signature:
|
||||
|
||||
```cpp
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view(DataType* __restrict__ p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
// Step 1: Create tensor descriptor (shape + stride information)
|
||||
auto desc = make_naive_tensor_descriptor(lengths,
|
||||
strides,
|
||||
number<GuaranteedLastDimensionVectorLength>{},
|
||||
number<GuaranteedLastDimensionVectorStride>{});
|
||||
|
||||
// Step 2: Create buffer view (pointer + size + address space)
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
// Step 3: Combine into tensor view
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Parameter Breakdown
|
||||
|
||||
### 1. **Template Parameter: `address_space_enum::global`**
|
||||
|
||||
Specifies where the memory lives:
|
||||
- `global`: GPU global memory (DRAM) - slowest but largest
|
||||
- `lds`: Local Data Share (shared memory) - fast, limited size
|
||||
- `generic`: Generic address space
|
||||
- `vgpr`: Vector General Purpose Registers - fastest, smallest
|
||||
|
||||
In our case, `global` means the data is in GPU DRAM.
|
||||
|
||||
### 2. **`p_a` - Raw Pointer**
|
||||
|
||||
The raw device memory pointer returned by `hipMalloc`. Points to the start of the matrix data.
|
||||
|
||||
### 3. **`make_tuple(M, K)` - Shape/Lengths**
|
||||
|
||||
Defines the logical dimensions of the tensor:
|
||||
- For matrix A: `(256, 32)` means 256 rows, 32 columns
|
||||
- This is the **logical view**, independent of how data is physically laid out
|
||||
|
||||
### 4. **`make_tuple(stride_a, 1)` - Strides**
|
||||
|
||||
Defines how to navigate through memory:
|
||||
- **Stride for dimension 0 (rows)**: `stride_a = K = 32`
|
||||
- To move to the next row, skip 32 elements
|
||||
- **Stride for dimension 1 (columns)**: `1`
|
||||
- To move to the next column, skip 1 element
|
||||
|
||||
**Row-major layout example:**
|
||||
```
|
||||
Memory: [a₀₀, a₀₁, a₀₂, ..., a₀₃₁, a₁₀, a₁₁, a₁₂, ..., a₁₃₁, ...]
|
||||
↑ ↑
|
||||
Row 0 starts here Row 1 starts here (offset = 32)
|
||||
|
||||
To access element A[i][j]:
|
||||
offset = i * stride_a + j * 1
|
||||
= i * 32 + j
|
||||
```
|
||||
|
||||
### 5. **`number<8>{}` - Guaranteed Last Dimension Vector Length**
|
||||
|
||||
This tells the tensor view: **"The last dimension (K) is guaranteed to have at least 8 consecutive elements that can be loaded together in a single vector instruction."**
|
||||
|
||||
#### Why is this important?
|
||||
|
||||
Modern GPUs can load multiple elements in one instruction (vectorized loads):
|
||||
- `float4`: Load 4 floats at once
|
||||
- `float8`: Load 8 floats at once (if supported)
|
||||
|
||||
By specifying `number<8>{}`, we're telling the system:
|
||||
- "You can safely use vector loads of up to 8 elements"
|
||||
- "The memory alignment and layout support this"
|
||||
|
||||
**Example:**
|
||||
```cpp
|
||||
// Without vectorization (slow):
|
||||
for (int j = 0; j < 8; j++) {
|
||||
data[j] = memory[offset + j]; // 8 separate loads
|
||||
}
|
||||
|
||||
// With vectorization (fast):
|
||||
float8 vec = *reinterpret_cast<float8*>(&memory[offset]); // 1 load!
|
||||
```
|
||||
|
||||
### 6. **`number<1>{}` - Guaranteed Last Dimension Vector Stride**
|
||||
|
||||
This specifies the **stride between consecutive vectorizable elements** in the last dimension.
|
||||
|
||||
- `number<1>{}` means: "Consecutive elements in the last dimension are contiguous in memory (stride = 1)"
|
||||
- This confirms that elements `A[i][0], A[i][1], A[i][2], ..., A[i][7]` are stored consecutively
|
||||
|
||||
**Why does this matter?**
|
||||
|
||||
For efficient vectorized loads, elements must be:
|
||||
1. **Contiguous** (stride = 1) ✓
|
||||
2. **Aligned** properly in memory
|
||||
3. **Within the same cache line** (ideally)
|
||||
|
||||
If the stride were `2`, it would mean:
|
||||
```
|
||||
A[i][0] is at offset 0
|
||||
A[i][1] is at offset 2 (not 1!)
|
||||
A[i][2] is at offset 4
|
||||
```
|
||||
This would prevent efficient vectorization.
|
||||
|
||||
---
|
||||
|
||||
## What is a Buffer View?
|
||||
|
||||
A **buffer view** is the middle layer between raw memory and tensor view. It provides:
|
||||
|
||||
### Core Responsibilities:
|
||||
|
||||
1. **Memory Management**
|
||||
- Holds the raw pointer: `T* p_data_`
|
||||
- Tracks buffer size: `BufferSizeType buffer_size_`
|
||||
- Knows the address space: `global`, `lds`, etc.
|
||||
|
||||
2. **Vectorized Access**
|
||||
```cpp
|
||||
template <typename VectorType>
|
||||
CK_TILE_DEVICE VectorType get(index_t offset);
|
||||
```
|
||||
- Provides efficient vector loads/stores
|
||||
- Handles alignment requirements
|
||||
|
||||
3. **Bounds Checking** (optional)
|
||||
```cpp
|
||||
template <bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto get(index_t i, index_t linear_offset);
|
||||
```
|
||||
- Can optionally check if access is within bounds
|
||||
- Returns invalid value (default 0) for out-of-bounds access
|
||||
|
||||
4. **Address Space Awareness**
|
||||
- Uses different load/store instructions based on address space
|
||||
- Global memory: `global_load`, `global_store`
|
||||
- LDS: `ds_read`, `ds_write`
|
||||
|
||||
### Buffer View Structure:
|
||||
|
||||
```cpp
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
typename T,
|
||||
typename BufferSizeType,
|
||||
bool InvalidElementUseNumericalZeroValue,
|
||||
amd_buffer_coherence_enum Coherence>
|
||||
struct buffer_view
|
||||
{
|
||||
T* p_data_; // Raw pointer
|
||||
BufferSizeType buffer_size_; // Total elements
|
||||
remove_cvref_t<T> invalid_element_value_; // Value for OOB access
|
||||
|
||||
// Access operators
|
||||
const T& operator[](index_t i) const; // Read
|
||||
T& operator()(index_t i); // Write
|
||||
|
||||
// Vectorized access
|
||||
template <typename VectorType>
|
||||
VectorType get(index_t offset);
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Visual Example: Matrix A Memory Layout
|
||||
|
||||
Let's visualize how matrix A (256×32, fp16) is organized:
|
||||
|
||||
### Raw Physical Memory (Linear):
|
||||
```
|
||||
GPU DRAM Address Space:
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Byte 0 │
|
||||
│ ↓ │
|
||||
│ [a₀₀][a₀₁][a₀₂]...[a₀₃₁][a₁₀][a₁₁][a₁₂]...[a₁₃₁][a₂₀]... │
|
||||
│ ↑ ↑ │
|
||||
│ Row 0 (32 elements) Row 1 (32 elements) │
|
||||
│ │
|
||||
│ Total: 256 rows × 32 cols × 2 bytes/element = 16,384 bytes │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
↑
|
||||
p_a (raw pointer)
|
||||
```
|
||||
|
||||
### Buffer View Layer:
|
||||
```
|
||||
buffer_view<address_space_enum::global, fp16_t, ...>
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ p_data_ = p_a │
|
||||
│ buffer_size_ = 256 × 32 = 8,192 elements │
|
||||
│ address_space = global (DRAM) │
|
||||
│ │
|
||||
│ Provides: │
|
||||
│ • Linear indexing: buffer_view[i] → element at offset i │
|
||||
│ • Vectorized loads: get<float4>(offset) → load 4 fp16s at once│
|
||||
│ • Bounds checking: is offset < buffer_size_? │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Tensor View Layer:
|
||||
```
|
||||
tensor_view<buffer_view, tensor_descriptor>
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Shape: (256, 32) │
|
||||
│ Strides: (32, 1) │
|
||||
│ Guaranteed vector length: 8 │
|
||||
│ Guaranteed vector stride: 1 │
|
||||
│ │
|
||||
│ Logical 2D View: │
|
||||
│ Col: 0 1 2 ... 31 │
|
||||
│ Row 0: [a₀₀][a₀₁][a₀₂] ... [a₀₃₁] ← Can vector load 8 at once│
|
||||
│ Row 1: [a₁₀][a₁₁][a₁₂] ... [a₁₃₁] │
|
||||
│ Row 2: [a₂₀][a₂₁][a₂₂] ... [a₂₃₁] │
|
||||
│ ... │
|
||||
│ Row 255: [a₂₅₅,₀] ... [a₂₅₅,₃₁] │
|
||||
│ │
|
||||
│ Provides: │
|
||||
│ • Multi-dimensional indexing: A[i][j] │
|
||||
│ • Coordinate transformation: (i,j) → linear offset = i*32 + j │
|
||||
│ • Tile window creation: Extract sub-tensors │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Complete Flow: Raw Memory → Tensor View
|
||||
|
||||
Let's trace the complete transformation for matrix A:
|
||||
|
||||
### Step 1: Kernel Launch (Host Side)
|
||||
```cpp
|
||||
// On host: Allocate device memory
|
||||
hipMalloc(&p_a, M * K * sizeof(fp16_t)); // Returns raw pointer
|
||||
|
||||
// Launch kernel
|
||||
kernel<<<grid, block>>>(p_a, p_b, p_c, M, N, K, ...);
|
||||
```
|
||||
|
||||
### Step 2: Inside Kernel (Device Side)
|
||||
```cpp
|
||||
// Receive raw pointer
|
||||
const fp16_t* p_a; // Points to GPU DRAM
|
||||
|
||||
// Step 2a: Create tensor descriptor
|
||||
auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(256, 32), // Shape
|
||||
make_tuple(32, 1), // Strides
|
||||
number<8>{}, // Vector length
|
||||
number<1>{} // Vector stride
|
||||
);
|
||||
// desc now knows: "This is a 256×32 tensor, row-major, vectorizable by 8"
|
||||
|
||||
// Step 2b: Create buffer view
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_a, // Raw pointer
|
||||
256 * 32 // Total elements
|
||||
);
|
||||
// buffer_view now wraps p_a with size and address space info
|
||||
|
||||
// Step 2c: Create tensor view
|
||||
auto a_dram = tensor_view{buffer_view, desc};
|
||||
// a_dram now provides structured, multi-dimensional access to p_a
|
||||
```
|
||||
|
||||
### Step 3: Using the Tensor View
|
||||
```cpp
|
||||
// Access element A[i][j]
|
||||
auto value = a_dram[make_tuple(i, j)];
|
||||
|
||||
// Create a tile window (sub-tensor)
|
||||
auto tile = make_tile_window(
|
||||
a_dram,
|
||||
make_tuple(16, 16), // 16×16 tile
|
||||
make_tuple(0, 0) // Starting at origin
|
||||
);
|
||||
|
||||
// Load tile into registers with vectorization
|
||||
auto tile_data = load_tile(tile); // Uses vector loads internally!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Why This Abstraction?
|
||||
|
||||
### Benefits:
|
||||
|
||||
1. **Type Safety**: Can't accidentally access wrong dimensions
|
||||
2. **Performance**: Compiler knows about vectorization opportunities
|
||||
3. **Flexibility**: Same code works for different memory spaces (DRAM, LDS, registers)
|
||||
4. **Maintainability**: Logical structure separate from physical layout
|
||||
5. **Optimization**: Guaranteed vector properties enable aggressive optimizations
|
||||
|
||||
### Example: Without Tensor Views (Manual Indexing)
|
||||
```cpp
|
||||
// Ugly, error-prone, hard to optimize:
|
||||
for (int i = 0; i < 16; i++) {
|
||||
for (int j = 0; j < 16; j++) {
|
||||
float val = p_a[tile_offset_i * stride_a + tile_offset_j + i * stride_a + j];
|
||||
// Hope the compiler vectorizes this? 🤞
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example: With Tensor Views (Clean, Optimized)
|
||||
```cpp
|
||||
// Clean, safe, automatically vectorized:
|
||||
auto tile = make_tile_window(a_dram, make_tuple(16, 16), origin);
|
||||
auto tile_data = load_tile(tile); // Vectorized loads guaranteed!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
The `PracticeGemmKernel` entry point transforms raw GPU memory into structured, multi-dimensional tensors through a three-layer abstraction:
|
||||
|
||||
1. **Raw Memory**: Linear array of bytes in GPU DRAM
|
||||
2. **Buffer View**: Adds size, address space, and vectorized access
|
||||
3. **Tensor View**: Adds shape, strides, and multi-dimensional indexing
|
||||
|
||||
This abstraction enables:
|
||||
- ✅ Clean, readable code
|
||||
- ✅ Type-safe multi-dimensional access
|
||||
- ✅ Automatic vectorization
|
||||
- ✅ Flexible memory space handling
|
||||
- ✅ Efficient tile-based computation
|
||||
|
||||
The tensor views created here are then passed to the host-level pipeline, which orchestrates the block-level GEMM computation!
|
||||
|
||||
150
tutorial/ck_tile/01_naive_gemm/README.md
Normal file
150
tutorial/ck_tile/01_naive_gemm/README.md
Normal file
@@ -0,0 +1,150 @@
|
||||
# CK Tile Practice GEMM Example
|
||||
|
||||
This is a practice implementation of a GEMM (General Matrix Multiplication) kernel using the CK Tile API. It demonstrates the fundamental concepts of GPU kernel development using CK Tile's hierarchical tile system.
|
||||
|
||||
## CK Tile API Structure
|
||||
|
||||
In the composable_kernel library's ck_tile API, **A Kernel is composed of a Problem, a Policy and an Epilogue**:
|
||||
|
||||
1. **Problem** describes the shape, data type, data layout, precision of our GEMM matrices
|
||||
2. **Policy** describes how the data in the matrix (or tile) is mapped to the threads
|
||||
3. **Epilogue** describes additional computation work performed after the gemm computations (this example does not have an epilogue)
|
||||
|
||||
## Overview
|
||||
|
||||
This example implements a complete GEMM kernel `C = A × B` using the CK Tile framework, showcasing:
|
||||
|
||||
- **Problem Setup** - Setting up the problem (input/output shapes, data types, mathematical operations), composing a kernel (pipeline, policy, epilogue), kernel launch
|
||||
- **Block-level Pipelining** - creating tensor views, dispatching to block-level GEMM
|
||||
- **Block-level GEMM Computation** - Block tiles, tile window creation, loading/storing to DRAM and Register memory
|
||||
- **Warp-level GEMM Computation** - Warp tiles, MFMA level computation
|
||||
|
||||
## Problem Setup and Data Flow
|
||||
|
||||
### Problem Size Configuration
|
||||
We set the problem size using the M, N and K variables:
|
||||
```cpp
|
||||
ck_tile::index_t M = 1024; // Number of rows in A and C
|
||||
ck_tile::index_t N = 512; // Number of columns in B and C
|
||||
ck_tile::index_t K = 256; // Number of columns in A, rows in B
|
||||
```
|
||||
|
||||
### Host Matrix Creation
|
||||
Three host matrices A (M×K), B (N×K) and C (M×N) are created, initialized on the CPU and copied over to the GPU global/DRAM memory:
|
||||
```cpp
|
||||
// Host tensors with proper strides
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides); // M × K
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides); // N × K
|
||||
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides); // M × N
|
||||
|
||||
// Initialize with random data
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
|
||||
|
||||
// Allocate device memory and transfer data
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
a_device.ToDevice(a_host.data());
|
||||
```
|
||||
|
||||
### PracticeGemmShape Configuration
|
||||
A PracticeGemmShape struct holds the dimension of each BlockTile and WaveTile:
|
||||
|
||||
```cpp
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>; // M, N, K per block
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>; // M, N, K per wave
|
||||
```
|
||||
- A BlockTile of size MxK (256x32) on A matrix and NxK (128x32) on B matrix. A WaveTile of size MxN (16x16) on C matrix.
|
||||
|
||||
|
||||
- BlockTiles iterate in K dimension to fetch data required for computing region of C covered by C's block tile.
|
||||
- BlockTiles are further subdivided into WarpTiles.
|
||||
- WarpTiles over A and B similarly work together to calculate the WarpTile of C.
|
||||
|
||||
### Problem and Policy Composition
|
||||
```cpp
|
||||
// A Problem is composed from Shape and info about the data
|
||||
using PracticeGemmHostProblem = ck_tile::
|
||||
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
|
||||
|
||||
// A Policy is created describing data-to-thread mapping
|
||||
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
|
||||
|
||||
// A Kernel is then composed of Problem and Policy
|
||||
using gemm_kernel = ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
|
||||
```
|
||||
|
||||
### Kernel Launch
|
||||
`ck_tile::launch_kernel()` is used to launch the kernel on device. It calls the `operator()` function of `PracticeGemmKernel{}`:
|
||||
```cpp
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 0, 1},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCU>(
|
||||
gemm_kernel{}, // Kernel composed of Problem + Policy
|
||||
kGridSize, // Grid dimensions
|
||||
kBlockSize, // Block dimensions
|
||||
0, // Dynamic shared memory
|
||||
// Kernel arguments: device buffers and problem dimensions
|
||||
a_device.GetDeviceBuffer(), b_device.GetDeviceBuffer(), c_device.GetDeviceBuffer(),
|
||||
M, N, K, stride_a, stride_b, stride_c));
|
||||
```
|
||||
|
||||
### Result Verification
|
||||
The results from the kernel are compared with results from CPU based computation function:
|
||||
```cpp
|
||||
// CPU reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(a_host, b_host, c_host_ref);
|
||||
|
||||
// Device results
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
|
||||
// Verify correctness
|
||||
bool pass = ck_tile::check_err(c_host_dev, c_host_ref);
|
||||
```
|
||||
|
||||
### Runtime Flow
|
||||
|
||||
The main program (`practice_gemm.cpp`) is the entry point for the runtime flow:
|
||||
|
||||
```cpp
|
||||
int main()
|
||||
{
|
||||
// 1. Define data types and problem sizes
|
||||
using ADataType = ck_tile::half_t;
|
||||
ck_tile::index_t M = 2048, N = 1024, K = 512;
|
||||
|
||||
// 2. Create host tensors and initialize
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
|
||||
// 3. Allocate device memory and transfer data
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
|
||||
// 4. Configure tile shapes
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>;
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// 5. Launch kernel
|
||||
using gemm_kernel = ck_tile::PracticeGemmKernel<Problem, Policy>;
|
||||
float ave_time = ck_tile::launch_kernel(/*...*/);
|
||||
|
||||
// 6. Verify results
|
||||
bool pass = verify_results(a_host, b_host, c_host);
|
||||
|
||||
// 7. Print performance metrics
|
||||
print_performance_metrics(ave_time, M, N, K);
|
||||
}
|
||||
```
|
||||
|
||||
## Building and Running
|
||||
|
||||
```bash
|
||||
# From composable_kernel root directory
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_practice_gemm -j
|
||||
|
||||
# Run with sample sizes
|
||||
./bin/tile_example_practice_gemm
|
||||
```
|
||||
This example serves as a foundation for understanding more complex GEMM implementations and optimization strategies in the CK Tile framework.
|
||||
506
tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md
Normal file
506
tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md
Normal file
@@ -0,0 +1,506 @@
|
||||
# Practice GEMM: Step-by-Step Code Walkthrough
|
||||
|
||||
This document provides a detailed walkthrough of `practice_gemm.cpp`, explaining each step of implementing a GEMM (General Matrix Multiplication) kernel using the CK Tile API.
|
||||
|
||||
## Overview
|
||||
|
||||
We'll implement `C = A × B` where:
|
||||
- `A` is an `M × K` matrix
|
||||
- `B` is an `N × K` matrix (note: transposed layout)
|
||||
- `C` is an `M × N` matrix
|
||||
|
||||
The implementation uses a hierarchical tiling strategy with two levels:
|
||||
1. **Block Tiles**: Processed by thread blocks
|
||||
2. **Wave Tiles**: Processed by warps (wavefronts) within blocks
|
||||
|
||||
---
|
||||
|
||||
## Step 1: Define Data Types
|
||||
|
||||
```cpp
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using CDataType = float;
|
||||
using AccDataType = float;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We use `half_t` (FP16) for input matrices A and B.
|
||||
- We use `float` (FP32) for output matrix C and accumulation for numerical accuracy
|
||||
- In typical CK examples, this information is part of a `GemmConfig` struct, but here we define it directly for simplicity
|
||||
---
|
||||
|
||||
## Step 2: Define Problem Size
|
||||
|
||||
```cpp
|
||||
ck_tile::index_t M = 512;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 64;
|
||||
ck_tile::index_t verification = 1;
|
||||
|
||||
ck_tile::index_t stride_a = K;
|
||||
ck_tile::index_t stride_b = K;
|
||||
ck_tile::index_t stride_c = N;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- `M = 512`: Number of rows in A and C
|
||||
- `N = 256`: Number of columns in B and C
|
||||
- `K = 64`: Inner dimension (columns of A, rows of B)
|
||||
- Strides define memory layout (row-major for A and C, transposed for B)
|
||||
|
||||
**Memory Layout:**
|
||||
```
|
||||
Matrix A (M×K): Matrix B (N×K): Matrix C (M×N):
|
||||
[512 rows] [256 rows] [512 rows]
|
||||
[64 cols] [64 cols] [256 cols]
|
||||
stride = K stride = K stride = N
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 3: Create Host Tensors
|
||||
|
||||
```cpp
|
||||
auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
|
||||
auto a_strides = std::array<ck_tile::index_t, 2>{stride_a, 1};
|
||||
auto b_strides = std::array<ck_tile::index_t, 2>{stride_b, 1};
|
||||
auto c_strides = std::array<ck_tile::index_t, 2>{stride_c, 1};
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides);
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We create three tensors on the host (CPU) memory
|
||||
- Each tensor is defined by its shape (`lengths`) and memory layout (`strides`)
|
||||
- `HostTensor` is a CK Tile utility class that manages CPU memory
|
||||
|
||||
**Stride explanation:**
|
||||
- For A: `stride_a = K` means moving to the next row requires skipping K elements
|
||||
- For B: `stride_b = K` means B is stored in transposed format
|
||||
- For C: `stride_c = N` means row-major layout
|
||||
|
||||
---
|
||||
|
||||
## Step 4: Initialize Tensors with Random Data
|
||||
|
||||
```cpp
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
|
||||
c_host.SetZero();
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- A and B are filled with random values in the range [-5.0, 5.0]
|
||||
- C is initialized to zero (will store the output)
|
||||
|
||||
**Optional: Print Tensor Contents**
|
||||
```cpp
|
||||
// Commented out in the code, but available for debugging:
|
||||
// a_host.print_first_n(10); // Print first 10 elements of A
|
||||
```
|
||||
|
||||
The `print_first_n()` helper function can display tensor contents for debugging purposes.
|
||||
|
||||
---
|
||||
|
||||
## Step 5: Allocate Device Memory and Transfer Data
|
||||
|
||||
```cpp
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
ck_tile::DeviceMem b_device(b_host);
|
||||
ck_tile::DeviceMem c_device(c_host);
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- `DeviceMem` allocates GPU memory matching the size of host tensors
|
||||
- The constructor **automatically transfers data from host to device**
|
||||
- This is a convenience wrapper around `hipMalloc` and `hipMemcpy`
|
||||
|
||||
**Memory Flow:**
|
||||
```
|
||||
CPU (Host) GPU (Device)
|
||||
┌─────────┐ ┌─────────┐
|
||||
│ a_host │ ────────> │a_device │
|
||||
│ b_host │ ────────> │b_device │
|
||||
│ c_host │ ────────> │c_device │
|
||||
└─────────┘ └─────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 6: Configure Hierarchical Tiling
|
||||
|
||||
```cpp
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>;
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
- We define a two-level tiling hierarchy for the GEMM computation
|
||||
|
||||
### Block Tile (256 × 128 × 32)
|
||||
- **256**: M dimension per block (rows of A and C)
|
||||
- **128**: N dimension per block (columns of B and C)
|
||||
- **32**: K dimension per block (inner dimension)
|
||||
- Each block tile is processed by one **thread block** (256 threads)
|
||||
|
||||
### Wave Tile (16 × 16 × 16)
|
||||
- **16 × 16**: Output tile dimensions (M × N) per warp iteration
|
||||
- **16**: K dimension per warp iteration
|
||||
- Each wave tile is processed by one **warp** (64 threads on AMD GPUs)
|
||||
|
||||
**Important:** The WaveTile (16×16×16) is NOT the same as the MFMA instruction size (32×32×8). The WaveTile represents the work done per warp per iteration, while MFMA is the underlying hardware instruction. Multiple MFMA operations may be needed to compute one wave tile
|
||||
|
||||
**Important Note:**
|
||||
In this example, the problem size (256 × 128 × 32) is **identical** to the block tile size, so only **one thread block** is needed to compute the entire problem.
|
||||
|
||||
### Tiling Visualization:
|
||||
|
||||
#### Matrix A (M × K = 256 × 32):
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ One Block Tile (256 × 32) │
|
||||
│ ┌────┬────┐ │
|
||||
│ │16×│16× │ ← Wave tiles (16×16) │
|
||||
│ │ 16│ 16 │ in M×K space │
|
||||
│ ├────┼────┤ │
|
||||
│ │ │ │ │
|
||||
│ ├────┼────┤ │
|
||||
│ │ .. │ .. │ 16 tiles in M │
|
||||
│ ├────┼────┤ 2 tiles in K │
|
||||
│ │ │ │ │
|
||||
│ └────┴────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────┘
|
||||
```
|
||||
|
||||
#### Matrix B (N × K = 128 × 32):
|
||||
```
|
||||
┌──────────────────────────────┐
|
||||
│ One Block Tile (128 × 32) │
|
||||
│ ┌────┬────┐ │
|
||||
│ │16×│16× │ ← Wave tiles │
|
||||
│ │ 16│ 16 │ (16×16) │
|
||||
│ ├────┼────┤ │
|
||||
│ │ │ │ │
|
||||
│ ├────┼────┤ 8 tiles in N │
|
||||
│ │ .. │ .. │ 2 tiles in K │
|
||||
│ ├────┼────┤ │
|
||||
│ │ │ │ │
|
||||
│ └────┴────┘ │
|
||||
└──────────────────────────────┘
|
||||
```
|
||||
|
||||
#### Matrix C (M × N = 256 × 128) - Output:
|
||||
```
|
||||
┌─────────────────────────────────────────────────┐
|
||||
│ One Block Tile (256 × 128) │
|
||||
│ │
|
||||
│ ┌────┬────┬────┬────┬────┬────┬────┬────┐ │
|
||||
│ │16× │ │ │ │ │ │ │ │ │
|
||||
│ │ 16 │ │ │ │ │ │ │ │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ │ │ │ │ │ │ │ │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ │ │ │ │ │ │ │ │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ │
|
||||
│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │
|
||||
│ │ │ │ │ │ │ │ │ │ │
|
||||
│ └────┴────┴────┴────┴────┴────┴────┴────┘ │
|
||||
│ │
|
||||
│ 16 wave tiles in M direction │
|
||||
│ 8 wave tiles in N direction │
|
||||
│ Total: 128 wave tiles (16×16 each) │
|
||||
└─────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
#### How Wave Tiles Combine (C = A × B):
|
||||
```
|
||||
Matrix A Matrix B (stored transposed N×K) Matrix C
|
||||
(256×32) (128×32) (256×128)
|
||||
|
||||
Row of A tiles: Row of B tiles: One wave tile in C:
|
||||
┌────┬────┐ ┌────┬────┐ ┌────┐
|
||||
│ A₀ │ A₁ │ × │ B₀ │ B₁ │ = │ C │ (16×16)
|
||||
└────┴────┘ └────┴────┘ └────┘
|
||||
16×16 each 16×16 each
|
||||
|
||||
Computation: C = A₀×B₀ᵀ + A₁×B₁ᵀ
|
||||
↑ ↑
|
||||
K=0..15 K=16..31
|
||||
|
||||
Each wave tile in C is computed by:
|
||||
- Taking one row of wave tiles from A (2 tiles along K)
|
||||
- Taking one row of wave tiles from B (2 tiles along K)
|
||||
Note: B is stored transposed (N×K), so a "row" in storage corresponds
|
||||
to a "column" in the logical B^T matrix used in computation
|
||||
- Performing dot product: Σ(A_k × B_k^T) for k=0,1
|
||||
```
|
||||
|
||||
**Key Insight:**
|
||||
- Each **wave tile in C** (16×16) requires a **dot product** of 2 wave tiles from A and 2 wave tiles from B
|
||||
- Since B is stored transposed (N×K layout), we access **rows** of B tiles in memory
|
||||
- This is the fundamental operation repeated across all 128 wave tiles in C
|
||||
- Each warp computes one wave tile using MFMA instructions
|
||||
|
||||
---
|
||||
|
||||
## Step 7: Create Shape, Problem, and Policy Structs
|
||||
|
||||
```cpp
|
||||
using PracticeGemmShape = ck_tile::PracticeGemmShape<BlockTile, WaveTile>;
|
||||
std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl;
|
||||
|
||||
using PracticeGemmHostProblem = ck_tile::
|
||||
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
|
||||
|
||||
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### 1. **Shape Struct**
|
||||
Encapsulates all tile shape information (BlockTile and WaveTile dimensions).
|
||||
|
||||
### 2. **Problem Struct**
|
||||
Holds complete problem description:
|
||||
- Data types (ADataType, BDataType, CDataType, AccDataType)
|
||||
- Shape information (BlockTile, WaveTile)
|
||||
|
||||
In more complex examples, this would also include:
|
||||
- Data layouts (row-major, column-major)
|
||||
- Mathematical operations (e.g., transposed GEMM)
|
||||
|
||||
### 3. **Policy Struct**
|
||||
Describes data movement and thread-to-data mapping:
|
||||
- Currently contains `MakeBlock2TileMap()`: Maps thread block IDs to tile positions
|
||||
- In more complex kernels, includes:
|
||||
- DRAM access patterns
|
||||
- LDS (Local Data Share) usage strategies
|
||||
- Thread distribution within blocks
|
||||
|
||||
**CK Tile Design Pattern:**
|
||||
```
|
||||
Kernel = Problem + Policy + Epilogue
|
||||
↑ ↑ ↑
|
||||
(What) (How) (Post-processing)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 8: Calculate Grid and Block Dimensions
|
||||
|
||||
```cpp
|
||||
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
|
||||
ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
|
||||
|
||||
std::cout << "kGridSize: " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
constexpr ck_tile::index_t kBlockPerCU = 1;
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### Grid Size Calculation
|
||||
```cpp
|
||||
kGridSize = ceil(M / BlockTile_M) × ceil(N / BlockTile_N)
|
||||
= ceil(512 / 256) × ceil(256 / 128)
|
||||
= 2 × 2
|
||||
= 4 thread blocks
|
||||
```
|
||||
|
||||
Our problem requires **4 thread blocks** to cover the entire output matrix C (2 blocks in M direction, 2 blocks in N direction).
|
||||
|
||||
### Block Configuration
|
||||
- `kBlockSize = 256`: Each thread block has 256 threads
|
||||
- 256 threads / 64 threads per warp = **4 warps per block**
|
||||
- `kBlockPerCU = 1`: Launch 1 block per Compute Unit (for simplicity)
|
||||
|
||||
**Thread Hierarchy:**
|
||||
```
|
||||
GPU
|
||||
└── 1 Thread Block (Grid)
|
||||
└── 256 Threads
|
||||
├── Warp 0 (threads 0-63)
|
||||
├── Warp 1 (threads 64-127)
|
||||
├── Warp 2 (threads 128-191)
|
||||
└── Warp 3 (threads 192-255)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 9: Create and Launch the Kernel
|
||||
|
||||
```cpp
|
||||
using gemm_kernel =
|
||||
ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 0, 1},
|
||||
ck_tile::make_kernel<kBlockPerCU>(gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_device.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c));
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### 1. Kernel Composition
|
||||
```cpp
|
||||
using gemm_kernel = ck_tile::PracticeGemmKernel<Problem, Policy>;
|
||||
```
|
||||
The kernel is composed from Problem and Policy structs, following the CK Tile design pattern.
|
||||
|
||||
### 2. Kernel Launch
|
||||
`launch_kernel()` is a CK Tile utility that:
|
||||
- Launches the GPU kernel using HIP runtime
|
||||
- Measures execution time
|
||||
- Returns average execution time in milliseconds
|
||||
|
||||
### 3. Launch Parameters
|
||||
- **Stream config**: `{nullptr, true, 0, 0, 1}` - default stream, timing enabled
|
||||
- **Grid size**: `kGridSize = 1` - number of thread blocks
|
||||
- **Block size**: `kBlockSize = 256` - threads per block
|
||||
- **Shared memory**: `0` - no dynamic shared memory in this example
|
||||
- **Kernel arguments**: Device pointers and problem dimensions
|
||||
|
||||
### 4. Kernel Execution Flow
|
||||
```
|
||||
launch_kernel() calls gemm_kernel.operator()()
|
||||
↓
|
||||
PracticeGemmKernel::operator()
|
||||
↓
|
||||
Creates tensor views over device memory
|
||||
↓
|
||||
Calls block-level pipeline
|
||||
↓
|
||||
Block pipeline calls warp-level pipeline
|
||||
↓
|
||||
Warp pipeline calls MFMA instructions
|
||||
↓
|
||||
Results written back to C matrix
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 10: Verify Results
|
||||
|
||||
```cpp
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// Reference gemm on CPU
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
|
||||
// Copy GPU results back to host
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
c_device.FromDevice(c_host_dev.mData.data());
|
||||
|
||||
// Compare results
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
```
|
||||
|
||||
**What's happening:**
|
||||
|
||||
### 1. CPU Reference Implementation
|
||||
```cpp
|
||||
reference_basic_gemm<...>(a_host, b_host, c_host_ref);
|
||||
```
|
||||
Computes GEMM on CPU using a simple nested loop implementation (ground truth).
|
||||
|
||||
### 2. Copy GPU Results to Host
|
||||
```cpp
|
||||
c_device.FromDevice(c_host_dev.mData.data());
|
||||
```
|
||||
Transfers the computed result from GPU memory back to CPU for comparison.
|
||||
|
||||
### 3. Error Checking
|
||||
```cpp
|
||||
ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
|
||||
```
|
||||
Compares GPU and CPU results element-wise with tolerance:
|
||||
- **Relative error**: 1e-3 (0.1%)
|
||||
- **Absolute error**: 1e-3
|
||||
|
||||
**Verification Flow:**
|
||||
```
|
||||
CPU GPU
|
||||
┌─────────┐ ┌─────────┐
|
||||
│ a_host │ ────────> │a_device │
|
||||
│ b_host │ ────────> │b_device │
|
||||
└─────────┘ └─────────┘
|
||||
│ │
|
||||
↓ ↓
|
||||
reference_gemm() GPU kernel
|
||||
│ │
|
||||
↓ ↓
|
||||
┌──────────┐ ┌──────────┐
|
||||
│c_host_ref│ │c_device │
|
||||
└──────────┘ └──────────┘
|
||||
│ │
|
||||
│ ↓
|
||||
│ FromDevice()
|
||||
│ │
|
||||
↓ ↓
|
||||
└────> check_err() <───┘
|
||||
│
|
||||
↓
|
||||
Pass/Fail
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Complete Execution Flow Summary
|
||||
|
||||
```
|
||||
1. Define data types (FP16 inputs, FP32 output)
|
||||
↓
|
||||
2. Set problem size (M=256, N=128, K=32)
|
||||
↓
|
||||
3. Create host tensors and initialize with random data
|
||||
↓
|
||||
4. Allocate device memory and transfer data (CPU → GPU)
|
||||
↓
|
||||
5. Configure hierarchical tiling (BlockTile, WaveTile)
|
||||
↓
|
||||
6. Create Shape, Problem, and Policy structs
|
||||
↓
|
||||
7. Calculate grid/block dimensions (1 block, 256 threads)
|
||||
↓
|
||||
8. Compose and launch kernel (Problem + Policy)
|
||||
↓
|
||||
9. Execute GEMM on GPU
|
||||
│ ├─ Block-level pipeline
|
||||
│ ├─ Warp-level pipeline
|
||||
│ └─ MFMA instructions
|
||||
↓
|
||||
10. Verify results (compare GPU vs CPU reference)
|
||||
↓
|
||||
11. Calculate and print performance metrics
|
||||
↓
|
||||
12. Return success/failure
|
||||
```
|
||||
|
||||
---
|
||||
@@ -0,0 +1,165 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = PracticeGemmBlockPolicy>
|
||||
struct PracticeGemmBlockPipelineAGmemBGmemCreg
|
||||
{
|
||||
using ADataType = typename Problem::ADataType;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
using CDataType = typename Problem::CDataType;
|
||||
using AccDataType = typename Problem::AccDataType;
|
||||
|
||||
using BlockTile = typename Problem::Shape::BlockTile;
|
||||
using WaveTile = typename Problem::Shape::WaveTile;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockTile::at(number<0>{});
|
||||
static constexpr index_t NPerBlock = BlockTile::at(number<1>{});
|
||||
static constexpr index_t KPerBlock = BlockTile::at(number<2>{});
|
||||
|
||||
static constexpr index_t MPerWave = WaveTile::at(number<0>{});
|
||||
static constexpr index_t NPerWave = WaveTile::at(number<1>{});
|
||||
static constexpr index_t KPerWave = WaveTile::at(number<2>{});
|
||||
|
||||
using BlockGemm =
|
||||
remove_cvref_t<decltype(Policy::template GetPracticeWaveGemmPipeline<Problem>())>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLDSSize()
|
||||
{
|
||||
return integer_divide_ceil(
|
||||
sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16) *
|
||||
16 +
|
||||
sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Definitions of all needed tiles
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
a_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
// non-prefetch
|
||||
index_t iCounter = num_loop;
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
a_block_tile = load_tile(a_copy_dram_window); // from DRAM to registers
|
||||
b_block_tile = load_tile(b_copy_dram_window); // from DRAM to registers
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
store_tile(a_copy_lds_window, a_block_tile); // from registers to LDS
|
||||
store_tile(b_copy_lds_window, b_block_tile); // from registers to LDS
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // from LDS to registers
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,135 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
#include "../warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp"
|
||||
#include "../warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename AccDataType_,
|
||||
typename Shape_>
|
||||
struct PracticeGemmBlockPipelineProblem
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using CDataType = CDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using Shape = Shape_;
|
||||
};
|
||||
|
||||
struct PracticeGemmBlockPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPracticeWaveGemmPipeline()
|
||||
{
|
||||
return PracticeGemmWarpPipelineASmemBSmemCreg<Problem>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{});
|
||||
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{});
|
||||
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPracticeWaveGemmPipeline<Problem>())>;
|
||||
constexpr index_t kMWarp = BlockGemm::MWarp;
|
||||
constexpr index_t kNWarp = BlockGemm::NWarp;
|
||||
constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size();
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{});
|
||||
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPracticeWaveGemmPipeline<Problem>())>;
|
||||
constexpr index_t kMWarp = BlockGemm::MWarp;
|
||||
constexpr index_t kNWarp = BlockGemm::NWarp;
|
||||
constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size();
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{});
|
||||
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = PracticeGemmHostPolicy>
|
||||
struct PracticeGemmHostPipeline
|
||||
{
|
||||
using ADataType = typename Problem_::ADataType;
|
||||
using BDataType = typename Problem_::BDataType;
|
||||
using CDataType = typename Problem_::CDataType;
|
||||
using AccDataType = typename Problem_::AccDataType;
|
||||
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using BlockTile = typename Problem::Shape::BlockTile;
|
||||
using WaveTile = typename Problem::Shape::WaveTile;
|
||||
|
||||
template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
|
||||
CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
|
||||
const BDRAMTensorView& b_dram,
|
||||
CDRAMTensorView& c_dram_ref) const
|
||||
{
|
||||
|
||||
// Size of the entire problem
|
||||
const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K
|
||||
const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N
|
||||
const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K
|
||||
|
||||
// Size of the block tile
|
||||
const auto MPerBlock = BlockTile::at(number<0>{});
|
||||
const auto NPerBlock = BlockTile::at(number<1>{});
|
||||
const auto KPerBlock = BlockTile::at(number<2>{});
|
||||
|
||||
// Number of block tile in the N direction to cover C (resultant) matrix
|
||||
const auto num_tile_n = integer_divide_ceil(N, NPerBlock);
|
||||
// Number of block tile in the M direction to cover C (resultant) matrix
|
||||
const auto num_tile_m = integer_divide_ceil(M, MPerBlock);
|
||||
|
||||
// if(get_thread_id() == 0 && get_block_id() == 0)
|
||||
// {
|
||||
// printf("num_tile_m: %d, num_tile_n: %d\n", num_tile_m, num_tile_n);
|
||||
// printf("total number of tiles: %d\n", num_tile_m * num_tile_n);
|
||||
// }
|
||||
|
||||
// Get block id
|
||||
const auto id_block =
|
||||
get_block_id(); // 0 to (M_block/BlockTile_M) * (N_block/BlockTile_N) - 1
|
||||
|
||||
// Map block id to tile id
|
||||
const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n);
|
||||
|
||||
const auto tile_id = block2tile(id_block);
|
||||
|
||||
const auto tile_id_m = tile_id.at(number<0>{});
|
||||
const auto tile_id_n = tile_id.at(number<1>{});
|
||||
|
||||
// if(get_thread_id() == 0 && get_block_id() == 15)
|
||||
// {
|
||||
// printf("tile_id_m: %d, tile_id_n: %d\n", tile_id_m, tile_id_n);
|
||||
// }
|
||||
|
||||
const auto tile_origin_m = tile_id_m * MPerBlock;
|
||||
const auto tile_origin_n = tile_id_n * NPerBlock;
|
||||
|
||||
// create a tile window over dram for A and B
|
||||
const auto a_block_window = make_tile_window(
|
||||
a_dram, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {tile_origin_m, 0});
|
||||
|
||||
const auto b_block_window = make_tile_window(
|
||||
b_dram, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {tile_origin_n, 0});
|
||||
|
||||
constexpr auto block_gemm_pipeline =
|
||||
Policy::template GetPracticeGemmBlockPipeline<Problem>();
|
||||
|
||||
int num_loops_k = integer_divide_ceil(K, KPerBlock);
|
||||
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()];
|
||||
const auto c_block_tile =
|
||||
block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char);
|
||||
auto c_window = make_tile_window(c_dram,
|
||||
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
|
||||
{tile_origin_m, tile_origin_n});
|
||||
store_tile(c_window, c_block_tile);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,51 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
#include "../block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp"
|
||||
#include "../block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename AccDataType_,
|
||||
typename Shape_>
|
||||
struct PracticeGemmHostProblem
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using CDataType = CDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using Shape = remove_cvref_t<Shape_>;
|
||||
};
|
||||
|
||||
struct PracticeGemmHostPolicy
|
||||
{
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
|
||||
|
||||
return [unmerge](index_t block_id) {
|
||||
multi_index<2> unmerged;
|
||||
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
|
||||
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPracticeGemmBlockPipeline()
|
||||
{
|
||||
using PracticeGemmBlockPipelineProblem_ =
|
||||
PracticeGemmBlockPipelineProblem<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
typename Problem::AccDataType,
|
||||
typename Problem::Shape>;
|
||||
return PracticeGemmBlockPipelineAGmemBGmemCreg<PracticeGemmBlockPipelineProblem_>{};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
131
tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp
Normal file
131
tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp
Normal file
@@ -0,0 +1,131 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "practice_gemm.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
|
||||
int main()
|
||||
{
|
||||
// TODO: GemmTypeConfig
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using CDataType = float;
|
||||
using AccDataType = float;
|
||||
|
||||
// ArgParser
|
||||
ck_tile::index_t M = 512;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 64;
|
||||
ck_tile::index_t verification = 1;
|
||||
|
||||
ck_tile::index_t stride_a = K;
|
||||
ck_tile::index_t stride_b = K;
|
||||
ck_tile::index_t stride_c = N;
|
||||
|
||||
auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
|
||||
auto a_strides = std::array<ck_tile::index_t, 2>{stride_a, 1};
|
||||
auto b_strides = std::array<ck_tile::index_t, 2>{stride_b, 1};
|
||||
auto c_strides = std::array<ck_tile::index_t, 2>{stride_c, 1};
|
||||
|
||||
// tensors on host (cpu)
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides);
|
||||
|
||||
// initialize tensors
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
|
||||
c_host.SetZero();
|
||||
|
||||
// Print the tensors using the new print_first_n member function
|
||||
// std::cout << "Tensor A (first 10 elements): ";
|
||||
// a_host.print_first_n(10);
|
||||
// std::cout << std::endl;
|
||||
|
||||
// std::cout << "Tensor B (first 10 elements): ";
|
||||
// b_host.print_first_n(10);
|
||||
// std::cout << std::endl;
|
||||
|
||||
// std::cout << "Tensor C (first 10 elements): ";
|
||||
// c_host.print_first_n(10);
|
||||
// std::cout << std::endl;
|
||||
|
||||
// Create device tensors of same size as host tensors and copy data
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
ck_tile::DeviceMem b_device(b_host);
|
||||
ck_tile::DeviceMem c_device(c_host);
|
||||
|
||||
// TODO: BlockTileConfig
|
||||
// constexpr ck_tile::index_t warpSize = 64;
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>;
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
std::cout << "Creating PracticeGemmShape, PracticeGemmProblem, PracticeGemmPolicy" << std::endl;
|
||||
using PracticeGemmShape = ck_tile::PracticeGemmShape<BlockTile, WaveTile>;
|
||||
std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl;
|
||||
using PracticeGemmHostProblem = ck_tile::
|
||||
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
|
||||
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
|
||||
|
||||
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
|
||||
ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
|
||||
|
||||
std::cout << "kGridSize: " << kGridSize << std::endl;
|
||||
constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU
|
||||
|
||||
std::cout << "kBlockSize: " << kBlockSize << std::endl;
|
||||
std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl;
|
||||
|
||||
using gemm_kernel =
|
||||
ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 0, 1},
|
||||
ck_tile::make_kernel<kBlockPerCU>(gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_device.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c));
|
||||
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// reference gemm
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
c_device.FromDevice(c_host_dev.mData.data());
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
69
tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp
Normal file
69
tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp
Normal file
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp"
|
||||
#include "host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockTile_, typename WaveTile_>
|
||||
struct PracticeGemmShape
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
using WaveTile = remove_cvref_t<WaveTile_>;
|
||||
|
||||
static constexpr index_t BlockTile_M = BlockTile::at(number<0>{});
|
||||
static constexpr index_t BlockTile_N = BlockTile::at(number<1>{});
|
||||
static constexpr index_t BlockTile_K = BlockTile::at(number<2>{});
|
||||
|
||||
static constexpr index_t WaveTile_M = WaveTile::at(number<0>{});
|
||||
static constexpr index_t WaveTile_N = WaveTile::at(number<1>{});
|
||||
static constexpr index_t WaveTile_K = WaveTile::at(number<2>{});
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "practice_gemm_shape",
|
||||
concat('x', BlockTile_M, BlockTile_N, BlockTile_K),
|
||||
concat('x', WaveTile_M, WaveTile_N, WaveTile_K));
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct PracticeGemmKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
|
||||
const typename Problem::BDataType* p_b,
|
||||
typename Problem::CDataType* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t stride_a,
|
||||
const index_t stride_b,
|
||||
const index_t stride_c) const
|
||||
{
|
||||
|
||||
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{});
|
||||
|
||||
auto b_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{});
|
||||
|
||||
const auto c_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{});
|
||||
|
||||
PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
36
tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp
Normal file
36
tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp
Normal file
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
void reference_basic_gemm(const ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
const ck_tile::HostTensor<BDataType>& b_n_k,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n)
|
||||
{
|
||||
const int N = b_n_k.mDesc.get_lengths()[0];
|
||||
const int K = b_n_k.mDesc.get_lengths()[1];
|
||||
|
||||
auto f = [&](auto m) {
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_m_k(m, k);
|
||||
BDataType v_b = b_n_k(n, k);
|
||||
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
||||
ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_acc);
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(1);
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = PracticeGemmWarpPolicy>
|
||||
struct PracticeGemmWarpPipelineASmemBSmemCreg
|
||||
{
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using WaveGemmShape = remove_cvref_t<typename Problem::Shape>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<
|
||||
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
|
||||
static constexpr index_t MWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
|
||||
static constexpr index_t NWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == WaveGemmShape::BlockTile_M &&
|
||||
NPerBlock == WaveGemmShape::BlockTile_N &&
|
||||
KPerBlock == WaveGemmShape::BlockTile_K,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
#if !defined(ENABLE_PREFETCH)
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// Warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == WaveGemmShape::BlockTile_M &&
|
||||
NPerBlock == WaveGemmShape::BlockTile_N &&
|
||||
KPerBlock == WaveGemmShape::BlockTile_K,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
|
||||
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
|
||||
|
||||
// Construct C-Block-Tensor
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmASmemBSmemCReg
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct PracticeGemmWarpPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
7
tutorial/ck_tile/CMakeLists.txt
Normal file
7
tutorial/ck_tile/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
include_directories(AFTER
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
add_subdirectory(00_copy_kernel)
|
||||
add_subdirectory(01_naive_gemm)
|
||||
|
||||
Reference in New Issue
Block a user