V4: update grid shape to 1D (B) instead of 2D (B,n)

Author: Damien Lejeune
Date:   2026-02-11 09:46:45 +00:00
parent 63dcefffc3
commit 055de18707
2 changed files with 75 additions and 35 deletions


@@ -51,12 +51,13 @@ struct MHCKernelV4
         return a_lds_size + b_lds_size;
     }

-    // Grid configuration: 2D grid over (batch, output_dim)
-    CK_TILE_HOST static constexpr auto GetGridSize(index_t batch, index_t output_dim)
+    // Grid configuration: 1D grid over the batch dimension only.
+    // Each block processes the full output dimension (kNTile=32 covers output_dim=24).
+    CK_TILE_HOST static constexpr auto GetGridSize(index_t batch,
+                                                   [[maybe_unused]] index_t output_dim)
     {
         const index_t grid_m = (batch + kMTile - 1) / kMTile;
-        const index_t grid_n = (output_dim + kNTile - 1) / kNTile;
-        return make_tuple(grid_m, grid_n);
+        return make_tuple(grid_m, 1); // 1D grid: tile only in the M dimension
     }

     CK_TILE_DEVICE void operator()(const XDataType* p_x,
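
For reference, the new launch math reduces to a single ceiling division. A minimal standalone sketch (plain C++; kMTile = 16 is taken from this kernel's configuration, everything else is illustrative):

    #include <cstdio>

    constexpr int kMTile = 16; // from this kernel's configuration

    // ceil(batch / kMTile): number of blocks in the 1D grid
    constexpr int grid_size_1d(int batch) { return (batch + kMTile - 1) / kMTile; }

    int main()
    {
        static_assert(grid_size_1d(24) == 2, "24 rows -> 2 blocks of 16");
        std::printf("batch=100 -> %d blocks\n", grid_size_1d(100)); // prints 7
        return 0;
    }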
@@ -72,16 +73,12 @@ struct MHCKernelV4
                                    [[maybe_unused]] float alpha_res = 1.0f,
                                    [[maybe_unused]] float bias      = 0.0f) const
     {
-        // 2D block indexing
-        const index_t grid_n_size = (output_dim + kNTile - 1) / kNTile;
+        // 1D block indexing: tile only in the M (batch) dimension
         const index_t block_id = get_block_id();
-        const index_t block_m = block_id / grid_n_size;
-        const index_t block_n = block_id % grid_n_size;
-        const index_t batch_start = block_m * kMTile;
-        const index_t out_start   = block_n * kNTile;
+        const index_t batch_start = block_id * kMTile;
+        const index_t out_start   = 0; // always process from output index 0

-        if(batch_start >= batch || out_start >= output_dim)
+        if(batch_start >= batch)
             return;

         const index_t tid = get_thread_id();
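
The guard above only trims whole blocks past the end of the batch. A quick host-side simulation of the block-to-row mapping (sketch only; kMTile mirrors the kernel, batch = 40 is an arbitrary ragged size):

    #include <algorithm>
    #include <cstdio>

    constexpr int kMTile = 16;

    int main()
    {
        const int batch  = 40; // deliberately not a multiple of kMTile
        const int blocks = (batch + kMTile - 1) / kMTile;
        for(int block_id = 0; block_id < blocks; ++block_id)
        {
            const int batch_start = block_id * kMTile;
            if(batch_start >= batch)
                continue; // mirrors the kernel's early return
            const int batch_end = std::min(batch_start + kMTile, batch);
            std::printf("block %d -> rows [%d, %d), out_start = 0\n",
                        block_id, batch_start, batch_end);
        }
        return 0;
    }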
@@ -213,31 +210,62 @@ struct MHCKernelV4
         }

         block_sync_lds();

-        // Warp-level reduction for each batch element
-        // Since we have 64 threads (1 warp) and kMTile=16, multiple threads contribute
-        // to each element
-        constexpr index_t threads_per_element =
-            kBlockSize / kMTile; // 64/16 = 4 threads per batch element
+        // Adaptive block-level reduction for each batch element.
+        // Example: with 128 threads and kMTile=16, 8 threads reduce each element.
+        constexpr index_t threads_per_element = kBlockSize / kMTile;
+        constexpr index_t warp_size           = 64; // AMD wavefront size

         for(index_t local_m = 0; local_m < kMTile; ++local_m)
         {
             ComputeDataType my_sum = thread_sum_squares[local_m];

-            // Warp shuffle reduction within threads handling this batch element
-            // Threads [local_m*4, local_m*4+1, local_m*4+2, local_m*4+3] reduce together
             const index_t my_group      = tid / threads_per_element;
             const index_t lane_in_group = tid % threads_per_element;

             if(my_group == local_m)
             {
-                // Reduce within this group of 4 threads using warp shuffle
+                // Step 1: warp-level shuffle reduction (works within a single warp)
+                constexpr index_t warp_reduce_size =
+                    (threads_per_element <= warp_size) ? threads_per_element : warp_size;
 #pragma unroll
-                for(index_t offset = threads_per_element / 2; offset > 0; offset /= 2)
+                for(index_t offset = warp_reduce_size / 2; offset > 0; offset /= 2)
                 {
                     my_sum += __shfl_down(my_sum, offset);
                 }
             }

-            // First thread in group writes to shared memory
+            // Step 2: cross-warp reduction, needed when threads_per_element > warp_size.
+            // This stage sits outside the group guard so that every thread in the block
+            // reaches the barrier; a barrier inside divergent control flow is undefined.
+            if constexpr(threads_per_element > warp_size)
+            {
+                constexpr index_t num_warps_per_element =
+                    (threads_per_element + warp_size - 1) / warp_size;
+                __shared__ ComputeDataType warp_partial_sums[kMTile][num_warps_per_element];
+
+                const index_t warp_id = lane_in_group / warp_size;
+                const index_t lane_id = lane_in_group % warp_size;
+
+                // First thread of each warp in the active group publishes its partial sum
+                if(my_group == local_m && lane_id == 0)
+                {
+                    warp_partial_sums[local_m][warp_id] = my_sum;
+                }
+                block_sync_lds();
+
+                // The first lanes of the group reduce the per-warp partial sums
+                if(my_group == local_m && lane_in_group < num_warps_per_element)
+                {
+                    my_sum = warp_partial_sums[local_m][lane_in_group];
+#pragma unroll
+                    for(index_t offset = num_warps_per_element / 2; offset > 0; offset /= 2)
+                    {
+                        my_sum += __shfl_down(my_sum, offset);
+                    }
+                }
+            }
+
+            // First thread in the group writes the final result
+            if(my_group == local_m && lane_in_group == 0)
             {
                 sum_squares_shared[local_m] = my_sum;
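
The two-step pattern above, isolated: a standalone HIP sketch of one thread group summing one value per thread. This is not the kernel's code; kGroupSize = 128 and the 64-lane wavefront are assumptions chosen to exercise the cross-warp path.

    #include <hip/hip_runtime.h>
    #include <cstdio>

    constexpr int kWarpSize  = 64;  // AMD wavefront
    constexpr int kGroupSize = 128; // the threads_per_element > warp_size case

    __global__ void group_sum(const float* in, float* out)
    {
        const int tid = threadIdx.x;
        float my_sum  = in[tid];

        // Stage 1: shuffle reduction inside each wavefront
        for(int offset = kWarpSize / 2; offset > 0; offset /= 2)
            my_sum += __shfl_down(my_sum, offset);

        // Stage 2: one partial per wavefront, reduced by the first wavefront
        constexpr int kNumWarps = kGroupSize / kWarpSize;
        __shared__ float partial[kNumWarps];
        if(tid % kWarpSize == 0)
            partial[tid / kWarpSize] = my_sum;
        __syncthreads(); // uniform: every thread in the block reaches it

        if(tid < kNumWarps)
        {
            my_sum = partial[tid];
            for(int offset = kNumWarps / 2; offset > 0; offset /= 2)
                my_sum += __shfl_down(my_sum, offset);
            if(tid == 0)
                *out = my_sum;
        }
    }

    int main()
    {
        float h_in[kGroupSize];
        for(int i = 0; i < kGroupSize; ++i) h_in[i] = 1.0f;

        float *d_in, *d_out, h_out = 0.0f;
        hipMalloc(&d_in, sizeof(h_in));
        hipMalloc(&d_out, sizeof(float));
        hipMemcpy(d_in, h_in, sizeof(h_in), hipMemcpyHostToDevice);

        group_sum<<<1, kGroupSize>>>(d_in, d_out);

        hipMemcpy(&h_out, d_out, sizeof(float), hipMemcpyDeviceToHost);
        std::printf("sum = %g (expected %d)\n", h_out, kGroupSize);
        hipFree(d_in);
        hipFree(d_out);
        return 0;
    }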


@@ -27,20 +27,20 @@ struct MHCProblemV4
     using CDataType = ComputeDataType; // Output/accumulator matrix C

     // BlockGemmShape with kM, kN, kK members for BlockGemm
-    // Using 16x16x16 tiles with 1 warp per block
-    using BlockGemmShape = TileGemmShape<sequence<16, 16, 16>,  // BlockTile (M, N, K)
-                                         sequence<1, 1, 1>,     // BlockWarps (1 warp per block)
-                                         sequence<16, 16, 16>>; // WarpTile (16x16x16 MFMA)
+    // Phase 2, simplified: 1D grid with 1 warp processing the full output (N=32).
+    // The warp issues 2 MFMA calls to cover 32 outputs (2 × 16 = 32).
+    using BlockGemmShape = TileGemmShape<sequence<16, 32, 16>,  // BlockTile (M=16, N=32, K=16)
+                                         sequence<1, 1, 1>,     // BlockWarps (1 warp total)
+                                         sequence<16, 32, 16>>; // WarpTile (16x32x16)

     // Vector sizes for loading
     static constexpr index_t VectorSizeA = 4;
     static constexpr index_t VectorSizeB = 4;

     // Derive BlockShape from BlockGemmShape
-    // Match V3's approach: use a simple 1×64 configuration for 1 warp
-    // This ensures proper tile distribution for load_tile/store_tile operations
+    // Back to 1 warp (64 threads) for the proven norm reduction
     using BlockShape =
-        Generic2dBlockShape<sequence<1, 64>, // BlockTile [1, 64] - simple layout for 1 warp
+        Generic2dBlockShape<sequence<1, 64>, // BlockTile [1, 64] - layout for 1 warp
                             sequence<1, 64>, // ThreadPerBlock [1, 64] = 64 threads (1 warp)
                             sequence<1, 1>>; // Vector [1, 1] - no vectorization in BlockShape
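
A quick consistency check of the shape arithmetic above (plain C++ sketch; the constants are copied from this configuration, the names are ours):

    int main()
    {
        constexpr int block_tile_n = 32;     // BlockTile N
        constexpr int warp_tile_n  = 32;     // WarpTile N
        constexpr int mfma_n       = 16;     // one 16x16x16 MFMA covers N=16
        constexpr int threads      = 1 * 64; // ThreadPerBlock [1, 64]

        static_assert(threads == 64, "exactly one 64-lane wavefront");
        static_assert(block_tile_n == warp_tile_n, "one warp owns the whole N tile");
        static_assert(warp_tile_n / mfma_n == 2, "2 MFMA calls cover N=32");
        return 0;
    }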
@@ -86,17 +86,17 @@ struct MHCProblemV4
     // Tile distribution for loading X (input matrix) from global memory
     // X is [Batch, nC] row-major, we load kM×kK tiles (16×16)
-    // For a 16×16 tile with 64 threads (1 warp):
+    // With 1 warp (64 threads):
     //   M: 1 repeat × 1 warp × 16 threads × 1 vector = 16
     //   K: 1 repeat × 1 warp × 4 threads × 4 vector = 16
-    //   Total threads: 1 warp × (16×4) = 64 threads ✓
+    //   Total threads: 1 warp × 64 threads = 64 threads ✓
     CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
     {
         using namespace ck_tile;
         // H0 (M dimension): [repeat=1, warp=1, thread=16, vector=1] = 16
         // H1 (K dimension): [repeat=1, warp=1, thread=4, vector=4] = 16
-        // P→RH: Warp layout = 1 warp in M × 1 warp in K
+        // P→RH: Warp layout = 1 warp in M × 1 warp in K = 1 warp total
         //       Thread layout = 16 threads in M × 4 threads in K = 64 threads/warp
         // Y→RH: Access order = M_repeat → M_vector → K_repeat → K_vector (vectorized)
         using XTileDistEncoding = tile_distribution_encoding<
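
The X encoding maps 64 lanes onto a 16×16 tile, each lane reading a 4-wide vector along K. A tiny decode sketch (plain C++; the lane ordering, with the K component fastest, is our assumption for illustration, not read from the encoding):

    #include <cstdio>

    int main()
    {
        constexpr int threads_m = 16, threads_k = 4, vec_k = 4;
        for(int tid = 0; tid < threads_m * threads_k; ++tid)
        {
            const int m = tid / threads_k;          // 16 lanes down M
            const int k = (tid % threads_k) * vec_k; // 4 lanes × 4-wide vectors = K=16
            if(tid < 5 || tid == 63) // print a few representative lanes
                std::printf("lane %2d -> X[m=%2d, k=%2d..%2d]\n",
                            tid, m, k, k + vec_k - 1);
        }
        return 0;
    }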
@@ -112,15 +112,27 @@ struct MHCProblemV4
     }

     // Tile distribution for loading Phi (weight matrix) from global memory
-    // Phi is [output_dim, nC] row-major, we load kN×kK tiles (16×16)
+    // Phi is [output_dim, nC] row-major, we load kN×kK tiles (32×16)
+    // With 1 warp (64 threads), vectorize by 2 in N to cover 32 elements:
+    //   N: 1 repeat × 1 warp × 16 threads × 2 vector = 32
+    //   K: 1 repeat × 1 warp × 4 threads × 4 vector = 16
+    //   Total threads: 1 warp × (16×4) = 64 threads ✓
     CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
     {
         using namespace ck_tile;
-        // Same distribution as X for 16×16 tiles
+        // H0 (N dimension): [repeat=1, warp=1, thread=16, vector=2] = 32
+        // H1 (K dimension): [repeat=1, warp=1, thread=4, vector=4] = 16
+        // P→RH: Warp layout = 1 warp in N × 1 warp in K = 1 warp total
+        //       Thread layout = 16 threads in N × 4 threads in K = 64 threads/warp
         using PhiTileDistEncoding = tile_distribution_encoding<
             sequence<>,                  // R: No replication
-            tuple<sequence<1, 1, 16, 1>, // H0 (N): repeat=1, warp=1, thread=16, vector=1
+            tuple<sequence<1, 1, 16, 2>, // H0 (N): repeat=1, warp=1, thread=16, vector=2
                   sequence<1, 1, 4, 4>>, // H1 (K): repeat=1, warp=1, thread=4, vector=4
             tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
             tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
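
As a sanity check that a 2-wide N vector per lane tiles the 32×16 Phi tile exactly once per element, a small host-side sketch (plain C++; the lane ordering is assumed, the names are ours):

    #include <cstdio>

    int main()
    {
        constexpr int N = 32, K = 16;
        constexpr int threads_n = 16, vec_n = 2, threads_k = 4, vec_k = 4;
        int covered[N][K] = {};
        for(int tid = 0; tid < threads_n * threads_k; ++tid) // 64 lanes
        {
            const int n0 = (tid / threads_k) * vec_n; // 2-wide patch in N
            const int k0 = (tid % threads_k) * vec_k; // 4-wide patch in K
            for(int dn = 0; dn < vec_n; ++dn)
                for(int dk = 0; dk < vec_k; ++dk)
                    ++covered[n0 + dn][k0 + dk];
        }
        int bad = 0;
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                if(covered[n][k] != 1) ++bad;
        std::printf("32x16 tile covered exactly once: %s\n", bad == 0 ? "yes" : "no");
        return 0;
    }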