diff --git a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp index f72b9913ab..9fea406fa1 100644 --- a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp +++ b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp @@ -51,12 +51,13 @@ struct MHCKernelV4 return a_lds_size + b_lds_size; } - // Grid configuration: 2D grid over (batch, output_dim) - CK_TILE_HOST static constexpr auto GetGridSize(index_t batch, index_t output_dim) + // Grid configuration: 1D grid over batch dimension only + // Each block processes full output dimension (kNTile=32 covers output_dim=24) + CK_TILE_HOST static constexpr auto GetGridSize(index_t batch, + [[maybe_unused]] index_t output_dim) { const index_t grid_m = (batch + kMTile - 1) / kMTile; - const index_t grid_n = (output_dim + kNTile - 1) / kNTile; - return make_tuple(grid_m, grid_n); + return make_tuple(grid_m, 1); // 1D grid: only tile in M dimension } CK_TILE_DEVICE void operator()(const XDataType* p_x, @@ -72,16 +73,12 @@ struct MHCKernelV4 [[maybe_unused]] float alpha_res = 1.0f, [[maybe_unused]] float bias = 0.0f) const { - // 2D block indexing - const index_t grid_n_size = (output_dim + kNTile - 1) / kNTile; + // 1D block indexing: only tile in M (batch) dimension const index_t block_id = get_block_id(); - const index_t block_m = block_id / grid_n_size; - const index_t block_n = block_id % grid_n_size; + const index_t batch_start = block_id * kMTile; + const index_t out_start = 0; // Always process from output index 0 - const index_t batch_start = block_m * kMTile; - const index_t out_start = block_n * kNTile; - - if(batch_start >= batch || out_start >= output_dim) + if(batch_start >= batch) return; const index_t tid = get_thread_id(); @@ -213,31 +210,62 @@ struct MHCKernelV4 } block_sync_lds(); - // Warp-level reduction for each batch element - // Since we have 64 threads (1 warp) and kMTile=16, multiple threads contribute to each - // element - constexpr index_t 
threads_per_element =
-            kBlockSize / kMTile; // 64/16 = 4 threads per batch element
+        // Adaptive block-level reduction for each batch element
+        // NOTE(review): "128 threads, 8 per element" contradicted the 1-warp (64-thread) BlockShape in mhc_problem_v4.hpp — confirm kBlockSize
+        constexpr index_t threads_per_element = kBlockSize / kMTile;
+        constexpr index_t warp_size = 64; // AMD wavefront size (64 on CDNA; not valid for wave32 targets)
 
         for(index_t local_m = 0; local_m < kMTile; ++local_m)
         {
             ComputeDataType my_sum = thread_sum_squares[local_m];
 
-            // Warp shuffle reduction within threads handling this batch element
-            // Threads [local_m*4, local_m*4+1, local_m*4+2, local_m*4+3] reduce together
             const index_t my_group      = tid / threads_per_element;
             const index_t lane_in_group = tid % threads_per_element;
 
             if(my_group == local_m)
             {
-// Reduce within this group of 4 threads using warp shuffle
+                // Step 1: Warp-level reduction (works within a single warp)
+                constexpr index_t warp_reduce_size =
+                    (threads_per_element <= warp_size) ? threads_per_element : warp_size;
+
 #pragma unroll
-                for(index_t offset = threads_per_element / 2; offset > 0; offset /= 2)
+                for(index_t offset = warp_reduce_size / 2; offset > 0; offset /= 2)
                 {
                     my_sum += __shfl_down(my_sum, offset);
                 }
 
-                // First thread in group writes to shared memory
+                // Step 2: cross-warp reduction. NOTE(review): the block_sync_lds() below is a barrier under divergent control flow (only my_group == local_m reaches it) — hoist the barrier out of the conditional
+                if constexpr(threads_per_element > warp_size)
+                {
+                    __shared__ ComputeDataType
+                        warp_partial_sums[kMTile]
+                                         [(threads_per_element + warp_size - 1) / warp_size];
+
+                    const index_t warp_id = lane_in_group / warp_size;
+                    const index_t lane_id = lane_in_group % warp_size;
+
+                    // First thread in each warp writes partial sum
+                    if(lane_id == 0)
+                    {
+                        warp_partial_sums[local_m][warp_id] = my_sum;
+                    }
+                    block_sync_lds();
+
+                    // First warp does final reduction across warp partial sums
+                    constexpr index_t num_warps_per_element =
+                        (threads_per_element + warp_size - 1) / warp_size;
+                    if(lane_in_group < num_warps_per_element)
+                    {
+                        my_sum = warp_partial_sums[local_m][lane_in_group];
+#pragma unroll
+                        for(index_t offset = num_warps_per_element / 2; offset > 0; offset /= 2)
+                        {
+                            my_sum += __shfl_down(my_sum, offset);
+                        }
+                    }
+                }
+
+                // First thread in group writes final result
                 if(lane_in_group == 0)
                 {
                     sum_squares_shared[local_m] = my_sum;
diff --git a/include/ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp b/include/ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp
index fc95b5d414..f9fc163b4f 100644
--- a/include/ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp
+++ b/include/ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp
@@ -27,20 +27,20 @@ struct MHCProblemV4
     using CDataType = ComputeDataType; // Output/accumulator matrix C
 
     // BlockGemmShape with kM, kN, kK members for BlockGemm
-    // Using 16x16x16 tiles with 1 warp per block
-    using BlockGemmShape = TileGemmShape<sequence<16, 16, 16>, // BlockTile (M, N, K)
-                                         sequence<1, 1, 1>,    // BlockWarps (1 warp per block)
-                                         sequence<16, 16, 16>>; // WarpTile (16x16x16 MFMA)
+    // Phase 2 Simplified: 1D grid with 1 warp, process full output (N=32)
+    // Use 2 MFMA calls per warp to cover 32 outputs (2 × 16 = 32)
+    using BlockGemmShape = TileGemmShape<sequence<16, 32, 16>, // BlockTile (M=16, N=32, K=16)
+                                         sequence<1, 1, 1>,    // BlockWarps (1 warp total)
+                                         sequence<16, 32, 16>>; // WarpTile — NOTE(review): 16x32x16 is not a native MFMA tile; if the intent is 2 MFMA repeats, WarpTile should stay 16x16x16 — confirm
 
     // Vector sizes for loading
     static constexpr index_t VectorSizeA = 4;
     static constexpr index_t VectorSizeB = 4;
 
     // Derive BlockShape from BlockGemmShape
-    // Match V3's approach: use a simple 1×64 configuration for 1 warp
-    // This ensures proper tile distribution for load_tile/store_tile operations
+    // Back to 1 warp (64 threads) for proven norm reduction
     using BlockShape =
-        Generic2dBlockShape<sequence<1, 64>, // BlockTile [1, 64] - simple layout for 1 warp
+        Generic2dBlockShape<sequence<1, 64>, // BlockTile [1, 64] - layout for 1 warp
                             sequence<1, 64>, // ThreadPerBlock [1, 64] = 64 threads (1 warp)
                             sequence<1, 1>>; // Vector [1, 1] - no vectorization in BlockShape
 
@@ -86,17 +86,17 @@ struct MHCProblemV4
 
     // Tile distribution for loading X (input matrix) from global memory
     // X is [Batch, nC] row-major, we load
kM×kK tiles (16×16)
-    // For a 16×16 tile with 64 threads (1 warp):
+    // With 1 warp (64 threads):
     // M: 1 repeat × 1 warp × 16 threads × 1 vector = 16
     // K: 1 repeat × 1 warp × 4 threads × 4 vector = 16
-    // Total threads: 1 warp × (16×4) = 64 threads ✓
+    // Total threads: 1 warp × 64 threads = 64 threads ✓
    CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
    {
        using namespace ck_tile;
 
        // H0 (M dimension): [repeat=1, warp=1, thread=16, vector=1] = 16
        // H1 (K dimension): [repeat=1, warp=1, thread=4, vector=4] = 16
-        // P→RH: Warp layout = 1 warp in M × 1 warp in K
+        // P→RH: Warp layout = 1 warp in M × 1 warp in K = 1 warp total
        // Thread layout = 16 threads in M × 4 threads in K = 64 threads/warp
        // Y→RH: Access order = M_repeat → M_vector → K_repeat → K_vector (vectorized)
        using XTileDistEncoding = tile_distribution_encoding<
            sequence<>, // R: No replication
            tuple<sequence<1, 1, 16, 1>, // H0 (M): repeat=1, warp=1, thread=16, vector=1
                  sequence<1, 1, 4, 4>>, // H1 (K): repeat=1, warp=1, thread=4, vector=4
            tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
            tuple<sequence<2, 2>, sequence<2, 2>>, // P→RH minor
            sequence<1, 1, 2, 2>, // Y→RH major
            sequence<0, 3, 0, 3>>; // Y→RH minor
        return make_static_tile_distribution(XTileDistEncoding{});
    }
@@ -112,15 +112,27 @@ struct MHCProblemV4
 
     // Tile distribution for loading Phi (weight matrix) from global memory
-    // Phi is [output_dim, nC] row-major, we load kN×kK tiles (16×16)
+    // Phi is [output_dim, nC] row-major, we load kN×kK tiles (32×16)
+    // With 1 warp (64 threads), use vector=2 in N to cover 32 elements:
+    // N: 1 repeat × 1 warp × 16 threads × 2 vector = 32
+    // K: 1 repeat × 1 warp × 4 threads × 4 vector = 16
+    // Total threads: 1 warp × 64 threads = 64 threads ✓
     CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
     {
         using namespace ck_tile;
-        // Same distribution as X for 16×16 tiles
+        // H0 (N dimension): [repeat=1, warp=1, thread=16, vector=2] = 32
+        // H1 (K dimension): [repeat=1, warp=1, thread=4, vector=4] = 16
+        // P→RH: Warp layout = 1 warp in N × 1 warp in K = 1 warp total
+        // Thread layout = 16 threads in N × 4 threads in K = 64 threads/warp
+        // (one full wavefront, matching the 1-warp BlockShape)
+        // N coverage: 16 threads × 2-element vectors = 32 rows of Phi
+        // K coverage: 4 threads × 4-element vectors = 16 columns
+        // Y→RH access order matches MakeXLoadTileDistribution above
+        // NOTE(review): vector=2 along N spans two rows of row-major Phi (non-contiguous) — confirm load coalescing is acceptable
         using PhiTileDistEncoding = tile_distribution_encoding<
             sequence<>, // R: No replication
-            tuple<sequence<1, 1, 16, 1>, // H0 (N): repeat=1, warp=1, thread=16, vector=1
+            tuple<sequence<1, 1, 16, 2>, // H0 (N): repeat=1, warp=1, thread=16, vector=2
                   sequence<1, 1, 4, 4>>, // H1 (K): repeat=1, warp=1, thread=4, vector=4
             tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
             tuple<sequence<2, 2>, sequence<2, 2>>, // P→RH minor