V5: experiment with multi-warp

Damien Lejeune
2026-02-12 14:39:20 +00:00
parent 0d7a341d27
commit 11d1c40655
4 changed files with 37 additions and 35 deletions

View File

@@ -13,6 +13,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/reference/reference_mhc.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v5_4warp.hpp"
// Parse command-line arguments for MHC benchmark
auto create_args(int argc, char* argv[])
@@ -95,6 +96,8 @@ bool run_mhc_benchmark_impl(const ck_tile::ArgParser& arg_parser)
d_phi_mem.ToDevice(h_phi.data());
d_output_mem.ToDevice(h_output.data());
// Reverted to adaptive 2-warp (4-warp produced incorrect results)
// using Problem = ck_tile::MHCProblemV5_4Warp<XDataType, ComputeDataType, YDataType>;
using Problem = ck_tile::MHCProblemV5<XDataType, ComputeDataType, YDataType, MTile>;
// V5 kernel - split-K implementation with adaptive problem
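For reference, one way to keep the 4-warp variant selectable for further experiments without re-editing this line is a compile-time switch. This is only a sketch; the MHC_USE_4WARP macro is hypothetical and not part of this commit:

// Hypothetical compile-time toggle (MHC_USE_4WARP is an assumed, user-defined macro).
// Build with -DMHC_USE_4WARP=1 to re-enable the 4-warp problem that currently
// produces incorrect results; the default stays on the adaptive 2-warp problem.
#ifdef MHC_USE_4WARP
using Problem = ck_tile::MHCProblemV5_4Warp<XDataType, ComputeDataType, YDataType>;
#else
using Problem = ck_tile::MHCProblemV5<XDataType, ComputeDataType, YDataType, MTile>;
#endif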

View File

@@ -13,6 +13,7 @@
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v5.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v5_4warp.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"

View File

@@ -62,7 +62,7 @@ struct MHCKernelV5
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
// LDS for BlockGemm with padding: A[kMTile, kKTile+8] + B[kNTile, kKTile+8]
// LDS for BlockGemm with padding: A[kMTile, kKTile+2] + B[kNTile, kKTile+2]
constexpr index_t a_lds_size = kMTile * kKTilePadded * sizeof(XDataType);
constexpr index_t b_lds_size = kNTile * kKTilePadded * sizeof(PhiDataType);
return a_lds_size + b_lds_size;
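As a sanity check on the LDS budget, the same arithmetic can be spelled out with concrete numbers. A self-contained sketch, assuming the M=16 configuration (kMTile=16, kNTile=32, kKTile=128), a 2-element pad, and 4-byte (fp32) elements for both operands; these concrete values are assumptions, only the formula above comes from the kernel:

#include <cstdio>

int main()
{
    // Assumed tile sizes (M=16 configuration) and 4-byte elements.
    constexpr int kMTile = 16, kNTile = 32, kKTile = 128, kPad = 2;
    constexpr int kKTilePadded = kKTile + kPad;       // padded K to avoid LDS bank conflicts
    constexpr int a_lds = kMTile * kKTilePadded * 4;  // A tile: 16 * 130 * 4 = 8320 bytes
    constexpr int b_lds = kNTile * kKTilePadded * 4;  // B tile: 32 * 130 * 4 = 16640 bytes
    std::printf("LDS per block: %d bytes\n", a_lds + b_lds); // 24960 bytes, well under 64 KiB
    return 0;
}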
@@ -176,7 +176,8 @@ struct MHCKernelV5
const index_t row_id = thread_id / threads_per_row;
const index_t thread_in_row = thread_id % threads_per_row;
__shared__ ComputeDataType norm_reduction[kMTile][threads_per_row];
// Use maximum possible threads_per_row for array size (kBlockSize for safety)
__shared__ ComputeDataType norm_reduction[kMTile][kBlockSize];
if(row_id < kMTile)
{
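For context, the enlarged shared array backs a per-row reduction: every lane assigned to a row writes its partial sum, then one lane combines them. A hedged sketch of that pattern follows; partial_sq_sum and row_norm are hypothetical names, and the serial combine stands in for whatever reduction scheme the kernel actually uses:

// Sketch of the per-row reduction the [kMTile][kBlockSize] array enables.
if(row_id < kMTile)
{
    norm_reduction[row_id][thread_in_row] = partial_sq_sum; // this lane's partial sum
}
__syncthreads();
if(row_id < kMTile && thread_in_row == 0)
{
    ComputeDataType row_norm = 0;
    for(index_t t = 0; t < threads_per_row; ++t) // serial combine, shown for clarity
        row_norm += norm_reduction[row_id][t];
    // ... use row_norm for the normalization step ...
}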

View File

@@ -30,18 +30,18 @@ struct MHCProblemV5
static constexpr index_t kMTile = MTile_; // Adaptive M tile size
// Adaptive tile configuration
// M=16 (default): Optimal for small/medium batches (B < 4096)
// M=64: Optimal for large batches (B >= 4096)
// N=32, K=128: Fixed for all configurations
using BlockGemmShape = TileGemmShape<sequence<MTile_, 32, 128>, // BlockTile: Adaptive M
// Adaptive tile configuration with K-loop optimization
// M: Adaptive (16 or 64) based on batch size
// N: 32 (covers output_dim=24 in a single tile)
// K: 128 (allows more K-tiles per block)
using BlockGemmShape = TileGemmShape<sequence<MTile_, 32, 128>, // BlockTile
sequence<1, 1, 1>, // BlockWarps: 1 warp
sequence<MTile_, 32, 128>>; // WarpTile: matches BlockTile
sequence<MTile_, 32, 128>>; // WarpTile
static constexpr index_t VectorSizeA = 4;
static constexpr index_t VectorSizeB = 4;
// 1 warp × 64 threads/warp = 64 threads (same as V4)
// 1 warp × 64 threads/warp = 64 threads
using BlockShape = Generic2dBlockShape<sequence<1, 64>, sequence<1, 64>, sequence<1, 1>>;
using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
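Because MTile_ is a compile-time parameter, the adaptive choice has to be made on the host before instantiating the problem. A minimal sketch of that dispatch, assuming a hypothetical run_with_mtile<>() helper and the B >= 4096 threshold quoted above:

// Hedged host-side dispatch sketch; run_with_mtile is a hypothetical helper that
// instantiates MHCProblemV5<..., MTile> and launches the kernel.
template <ck_tile::index_t MTile>
void run_with_mtile(/* kernel arguments */);

void run_adaptive(ck_tile::index_t batch)
{
    if(batch >= 4096)
        run_with_mtile<64>(); // large batches: more rows per block
    else
        run_with_mtile<16>(); // small/medium batches: default
}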
@@ -79,45 +79,42 @@ struct MHCProblemV5
CK_TILE_HOST static const std::string GetName() { return "MHCProblemV5"; }
// Adaptive tile distribution for loading X (input matrix)
// X is [Batch, nC] row-major, we load kM×kK tiles
// For M=16: H0 (M): [grid=1, warp=1, thread=16, vector=1] = 16
// For M=64: H0 (M): [grid=4, warp=1, thread=16, vector=1] = 64
// H1 (K): [grid=2, warp=1, thread=4, vector=16] = 128 (same for all)
// Tile distribution for loading X: Adaptive_M × 128
// M: Adaptive (16 or 64)
// K: 128 = 2×1×4×16 (4 threads × 16-wide vectors over 2 K-iterations for better coalescing)
CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
{
using namespace ck_tile;
constexpr index_t m_grid = MTile_ / 16; // M=16 → grid=1, M=64 → grid=4
constexpr index_t m_grid = MTile_ / 16;
using XTileDistEncoding = tile_distribution_encoding<
sequence<>, // R: No replication
tuple<sequence<m_grid, 1, 16, 1>, // H0 (M): adaptive grid based on MTile_
sequence<2, 1, 4, 16>>, // H1 (K): grid=2, warp=1, thread=4, vector=16
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major: warp arrangement
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor: thread arrangement
sequence<1, 1, 2, 2>, // Y→RH major: data layout
sequence<0, 3, 0, 3>>; // Y→RH minor: vectorization
using XTileDistEncoding =
tile_distribution_encoding<sequence<>, // R: No replication
tuple<sequence<m_grid, 1, 16, 1>, // H0 (M): adaptive
sequence<2, 1, 4, 16>>, // H1 (K): 128 = 2×1×4×16
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
sequence<1, 1, 2, 2>, // Y→RH major
sequence<0, 3, 0, 3>>; // Y→RH minor
return make_static_tile_distribution(XTileDistEncoding{});
}
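A quick way to see that this encoding spans the whole kM×kK tile is to multiply out the H-sequence lengths. A self-contained sketch of that check (plain constexpr arithmetic, shown for MTile_ = 64, not the ck_tile API):

// Hedged coverage check for the X-load encoding, evaluated for MTile_ = 64.
constexpr int m_grid = 64 / 16;              // MTile_ / 16
constexpr int h0_m   = m_grid * 1 * 16 * 1;  // H0 (M): 4 * 1 * 16 * 1 = 64
constexpr int h1_k   = 2 * 1 * 4 * 16;       // H1 (K): 2 * 1 * 4 * 16 = 128
static_assert(h0_m == 64 && h1_k == 128, "X encoding must cover the 64x128 tile");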
// Tile distribution for loading Phi (weight matrix)
// Phi is [output_dim, nC] row-major, we load kN×kK tiles (32×128)
// H0 (N): [grid=1, warp=1, thread=16, vector=2] = 32
// H1 (K): [grid=2, warp=1, thread=4, vector=16] = 128
// Tile distribution for loading Phi: 32 × 128
// N: 32 = 1×1×16×2 (16 threads × 2 vector, fits output_dim=24)
// K: 128 = 2×1×4×16 (matches X distribution)
CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
{
using namespace ck_tile;
using PhiTileDistEncoding = tile_distribution_encoding<
sequence<>, // R: No replication
tuple<sequence<1, 1, 16, 2>, // H0 (N): grid=1, warp=1, thread=16, vector=2
sequence<2, 1, 4, 16>>, // H1 (K): grid=2, warp=1, thread=4, vector=16
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major: warp arrangement
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor: thread arrangement
sequence<1, 1, 2, 2>, // Y→RH major: data layout
sequence<0, 3, 0, 3>>; // Y→RH minor: vectorization
using PhiTileDistEncoding =
tile_distribution_encoding<sequence<>, // R: No replication
tuple<sequence<1, 1, 16, 2>, // H0 (N): 32 = 1×1×16×2
sequence<2, 1, 4, 16>>, // H1 (K): 128 = 2×1×4×16
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
sequence<1, 1, 2, 2>, // Y→RH major
sequence<0, 3, 0, 3>>; // Y→RH minor
return make_static_tile_distribution(PhiTileDistEncoding{});
}
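The Phi encoding reuses the K split from the X load, so the two distributions share one thread layout: 16 threads along N times 4 threads along K is exactly the 64-thread (single-wavefront) block from BlockShape. A short self-contained sketch of that bookkeeping; the per-thread footprint is read off the vector lengths and is an interpretation, not something stated in the diff:

// Hedged bookkeeping for the Phi load: thread layout and per-thread footprint.
constexpr int threads_n = 16, threads_k = 4; // thread dims of the two H-sequences
constexpr int vec_n = 2, vec_k = 16;         // per-thread vector lengths
static_assert(threads_n * threads_k == 64,
              "N-threads x K-threads matches the 64-thread, 1-wavefront block");
static_assert(threads_n * vec_n == 32 && 2 * threads_k * vec_k == 128,
              "N covered by 16x2; K covered by 4 threads x 16-wide vectors x 2 iterations");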