mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
V5: experiment with multi-warp
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/reference/reference_mhc.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v5_4warp.hpp"
|
||||
|
||||
// Parse command-line arguments for MHC benchmark
|
||||
auto create_args(int argc, char* argv[])
|
||||
@@ -95,6 +96,8 @@ bool run_mhc_benchmark_impl(const ck_tile::ArgParser& arg_parser)
|
||||
d_phi_mem.ToDevice(h_phi.data());
|
||||
d_output_mem.ToDevice(h_output.data());
|
||||
|
||||
// Reverted to adaptive 2-warp (4-warp produced incorrect results)
|
||||
// using Problem = ck_tile::MHCProblemV5_4Warp<XDataType, ComputeDataType, YDataType>;
|
||||
using Problem = ck_tile::MHCProblemV5<XDataType, ComputeDataType, YDataType, MTile>;
|
||||
|
||||
// V5 kernel - split-K implementation with adaptive problem
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v5.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v5_4warp.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
|
||||
@@ -62,7 +62,7 @@ struct MHCKernelV5
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
// LDS for BlockGemm with padding: A[kMTile, kKTile+8] + B[kNTile, kKTile+8]
|
||||
// LDS for BlockGemm with padding: A[kMTile, kKTile+2] + B[kNTile, kKTile+2]
|
||||
constexpr index_t a_lds_size = kMTile * kKTilePadded * sizeof(XDataType);
|
||||
constexpr index_t b_lds_size = kNTile * kKTilePadded * sizeof(PhiDataType);
|
||||
return a_lds_size + b_lds_size;
|
||||
@@ -176,7 +176,8 @@ struct MHCKernelV5
|
||||
const index_t row_id = thread_id / threads_per_row;
|
||||
const index_t thread_in_row = thread_id % threads_per_row;
|
||||
|
||||
__shared__ ComputeDataType norm_reduction[kMTile][threads_per_row];
|
||||
// Use maximum possible threads_per_row for array size (kBlockSize for safety)
|
||||
__shared__ ComputeDataType norm_reduction[kMTile][kBlockSize];
|
||||
|
||||
if(row_id < kMTile)
|
||||
{
|
||||
|
||||
@@ -30,18 +30,18 @@ struct MHCProblemV5
|
||||
|
||||
static constexpr index_t kMTile = MTile_; // Adaptive M tile size
|
||||
|
||||
// Adaptive tile configuration
|
||||
// M=16 (default): Optimal for small/medium batches (B < 4096)
|
||||
// M=64: Optimal for large batches (B >= 4096)
|
||||
// N=32, K=128: Fixed for all configurations
|
||||
using BlockGemmShape = TileGemmShape<sequence<MTile_, 32, 128>, // BlockTile: Adaptive M
|
||||
// Adaptive tile configuration with K-loop optimization
|
||||
// M: Adaptive (16 or 64) based on batch size
|
||||
// N: 32 (wider than output_dim=24, so one N tile covers the whole output)
|
||||
// K: 128 (allows more K-tiles per block)
|
||||
using BlockGemmShape = TileGemmShape<sequence<MTile_, 32, 128>, // BlockTile
|
||||
sequence<1, 1, 1>, // BlockWarps: 1 warp
|
||||
sequence<MTile_, 32, 128>>; // WarpTile: matches BlockTile
|
||||
sequence<MTile_, 32, 128>>; // WarpTile
|
||||
|
||||
static constexpr index_t VectorSizeA = 4;
|
||||
static constexpr index_t VectorSizeB = 4;
|
||||
|
||||
// 1 warp × 64 threads/warp = 64 threads (same as V4)
|
||||
// 1 warp × 64 threads/warp = 64 threads
|
||||
using BlockShape = Generic2dBlockShape<sequence<1, 64>, sequence<1, 64>, sequence<1, 1>>;
|
||||
|
||||
using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
@@ -79,45 +79,42 @@ struct MHCProblemV5
|
||||
|
||||
CK_TILE_HOST static const std::string GetName() { return "MHCProblemV5"; }
|
||||
|
||||
// Adaptive tile distribution for loading X (input matrix)
|
||||
// X is [Batch, nC] row-major, we load kM×kK tiles
|
||||
// For M=16: H0 (M): [grid=1, warp=1, thread=16, vector=1] = 16
|
||||
// For M=64: H0 (M): [grid=4, warp=1, thread=16, vector=1] = 64
|
||||
// H1 (K): [grid=2, warp=1, thread=4, vector=16] = 128 (same for all)
|
||||
// Tile distribution for loading X: Adaptive_M × 128
|
||||
// M: Adaptive (16 or 64)
|
||||
// K: 128 = 2×1×4×16 (4 threads × 16 vector for better coalescing)
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t m_grid = MTile_ / 16; // M=16 → grid=1, M=64 → grid=4
|
||||
constexpr index_t m_grid = MTile_ / 16;
|
||||
|
||||
using XTileDistEncoding = tile_distribution_encoding<
|
||||
sequence<>, // R: No replication
|
||||
tuple<sequence<m_grid, 1, 16, 1>, // H0 (M): adaptive grid based on MTile_
|
||||
sequence<2, 1, 4, 16>>, // H1 (K): grid=2, warp=1, thread=4, vector=16
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major: warp arrangement
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor: thread arrangement
|
||||
sequence<1, 1, 2, 2>, // Y→RH major: data layout
|
||||
sequence<0, 3, 0, 3>>; // Y→RH minor: vectorization
|
||||
using XTileDistEncoding =
|
||||
tile_distribution_encoding<sequence<>, // R: No replication
|
||||
tuple<sequence<m_grid, 1, 16, 1>, // H0 (M): adaptive
|
||||
sequence<2, 1, 4, 16>>, // H1 (K): 128 = 2×1×4×16
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
|
||||
sequence<1, 1, 2, 2>, // Y→RH major
|
||||
sequence<0, 3, 0, 3>>; // Y→RH minor
|
||||
|
||||
return make_static_tile_distribution(XTileDistEncoding{});
|
||||
}
|
||||
|
||||
// Tile distribution for loading Phi (weight matrix)
|
||||
// Phi is [output_dim, nC] row-major, we load kN×kK tiles (32×128)
|
||||
// H0 (N): [grid=1, warp=1, thread=16, vector=2] = 32
|
||||
// H1 (K): [grid=2, warp=1, thread=4, vector=16] = 128
|
||||
// Tile distribution for loading Phi: 32 × 128
|
||||
// N: 32 = 1×1×16×2 (16 threads × 2 vector, fits output_dim=24)
|
||||
// K: 128 = 2×1×4×16 (matches X distribution)
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
using PhiTileDistEncoding = tile_distribution_encoding<
|
||||
sequence<>, // R: No replication
|
||||
tuple<sequence<1, 1, 16, 2>, // H0 (N): grid=1, warp=1, thread=16, vector=2
|
||||
sequence<2, 1, 4, 16>>, // H1 (K): grid=2, warp=1, thread=4, vector=16
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major: warp arrangement
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor: thread arrangement
|
||||
sequence<1, 1, 2, 2>, // Y→RH major: data layout
|
||||
sequence<0, 3, 0, 3>>; // Y→RH minor: vectorization
|
||||
using PhiTileDistEncoding =
|
||||
tile_distribution_encoding<sequence<>, // R: No replication
|
||||
tuple<sequence<1, 1, 16, 2>, // H0 (N): 32 = 1×1×16×2
|
||||
sequence<2, 1, 4, 16>>, // H1 (K): 128 = 2×1×4×16
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
|
||||
sequence<1, 1, 2, 2>, // Y→RH major
|
||||
sequence<0, 3, 0, 3>>; // Y→RH minor
|
||||
|
||||
return make_static_tile_distribution(PhiTileDistEncoding{});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user