mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
V4: update grid shape to 1D (B) instead of 2D (B,n)
This commit is contained in:
@@ -51,12 +51,13 @@ struct MHCKernelV4
|
||||
return a_lds_size + b_lds_size;
|
||||
}
|
||||
|
||||
// Grid configuration: 2D grid over (batch, output_dim)
|
||||
CK_TILE_HOST static constexpr auto GetGridSize(index_t batch, index_t output_dim)
|
||||
// Grid configuration: 1D grid over batch dimension only
|
||||
// Each block processes full output dimension (kNTile=32 covers output_dim=24)
|
||||
CK_TILE_HOST static constexpr auto GetGridSize(index_t batch,
|
||||
[[maybe_unused]] index_t output_dim)
|
||||
{
|
||||
const index_t grid_m = (batch + kMTile - 1) / kMTile;
|
||||
const index_t grid_n = (output_dim + kNTile - 1) / kNTile;
|
||||
return make_tuple(grid_m, grid_n);
|
||||
return make_tuple(grid_m, 1); // 1D grid: only tile in M dimension
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x,
|
||||
@@ -72,16 +73,12 @@ struct MHCKernelV4
|
||||
[[maybe_unused]] float alpha_res = 1.0f,
|
||||
[[maybe_unused]] float bias = 0.0f) const
|
||||
{
|
||||
// 2D block indexing
|
||||
const index_t grid_n_size = (output_dim + kNTile - 1) / kNTile;
|
||||
// 1D block indexing: only tile in M (batch) dimension
|
||||
const index_t block_id = get_block_id();
|
||||
const index_t block_m = block_id / grid_n_size;
|
||||
const index_t block_n = block_id % grid_n_size;
|
||||
const index_t batch_start = block_id * kMTile;
|
||||
const index_t out_start = 0; // Always process from output index 0
|
||||
|
||||
const index_t batch_start = block_m * kMTile;
|
||||
const index_t out_start = block_n * kNTile;
|
||||
|
||||
if(batch_start >= batch || out_start >= output_dim)
|
||||
if(batch_start >= batch)
|
||||
return;
|
||||
|
||||
const index_t tid = get_thread_id();
|
||||
@@ -213,31 +210,62 @@ struct MHCKernelV4
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Warp-level reduction for each batch element
|
||||
// Since we have 64 threads (1 warp) and kMTile=16, multiple threads contribute to each
|
||||
// element
|
||||
constexpr index_t threads_per_element =
|
||||
kBlockSize / kMTile; // 64/16 = 4 threads per batch element
|
||||
// Adaptive block-level reduction for each batch element
|
||||
// Current: 128 threads, kMTile=16 → 8 threads per batch element
|
||||
constexpr index_t threads_per_element = kBlockSize / kMTile;
|
||||
constexpr index_t warp_size = 64; // AMD warp size
|
||||
|
||||
for(index_t local_m = 0; local_m < kMTile; ++local_m)
|
||||
{
|
||||
ComputeDataType my_sum = thread_sum_squares[local_m];
|
||||
|
||||
// Warp shuffle reduction within threads handling this batch element
|
||||
// Threads [local_m*4, local_m*4+1, local_m*4+2, local_m*4+3] reduce together
|
||||
const index_t my_group = tid / threads_per_element;
|
||||
const index_t lane_in_group = tid % threads_per_element;
|
||||
|
||||
if(my_group == local_m)
|
||||
{
|
||||
// Reduce within this group of 4 threads using warp shuffle
|
||||
// Step 1: Warp-level reduction (works within a single warp)
|
||||
constexpr index_t warp_reduce_size =
|
||||
(threads_per_element <= warp_size) ? threads_per_element : warp_size;
|
||||
|
||||
#pragma unroll
|
||||
for(index_t offset = threads_per_element / 2; offset > 0; offset /= 2)
|
||||
for(index_t offset = warp_reduce_size / 2; offset > 0; offset /= 2)
|
||||
{
|
||||
my_sum += __shfl_down(my_sum, offset);
|
||||
}
|
||||
|
||||
// First thread in group writes to shared memory
|
||||
// Step 2: Cross-warp reduction if needed (threads_per_element > warp_size)
|
||||
if constexpr(threads_per_element > warp_size)
|
||||
{
|
||||
__shared__ ComputeDataType
|
||||
warp_partial_sums[kMTile]
|
||||
[(threads_per_element + warp_size - 1) / warp_size];
|
||||
|
||||
const index_t warp_id = lane_in_group / warp_size;
|
||||
const index_t lane_id = lane_in_group % warp_size;
|
||||
|
||||
// First thread in each warp writes partial sum
|
||||
if(lane_id == 0)
|
||||
{
|
||||
warp_partial_sums[local_m][warp_id] = my_sum;
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// First warp does final reduction across warp partial sums
|
||||
constexpr index_t num_warps_per_element =
|
||||
(threads_per_element + warp_size - 1) / warp_size;
|
||||
if(lane_in_group < num_warps_per_element)
|
||||
{
|
||||
my_sum = warp_partial_sums[local_m][lane_in_group];
|
||||
#pragma unroll
|
||||
for(index_t offset = num_warps_per_element / 2; offset > 0; offset /= 2)
|
||||
{
|
||||
my_sum += __shfl_down(my_sum, offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// First thread in group writes final result
|
||||
if(lane_in_group == 0)
|
||||
{
|
||||
sum_squares_shared[local_m] = my_sum;
|
||||
|
||||
@@ -27,20 +27,20 @@ struct MHCProblemV4
|
||||
using CDataType = ComputeDataType; // Output/accumulator matrix C
|
||||
|
||||
// BlockGemmShape with kM, kN, kK members for BlockGemm
|
||||
// Using 16x16x16 tiles with 1 warp per block
|
||||
using BlockGemmShape = TileGemmShape<sequence<16, 16, 16>, // BlockTile (M, N, K)
|
||||
sequence<1, 1, 1>, // BlockWarps (1 warp per block)
|
||||
sequence<16, 16, 16>>; // WarpTile (16x16x16 MFMA)
|
||||
// Phase 2 Simplified: 1D grid with 1 warp, process full output (N=32)
|
||||
// Use 2 MFMA calls per warp to cover 32 outputs (2 × 16 = 32)
|
||||
using BlockGemmShape = TileGemmShape<sequence<16, 32, 16>, // BlockTile (M=16, N=32, K=16)
|
||||
sequence<1, 1, 1>, // BlockWarps (1 warp total)
|
||||
sequence<16, 32, 16>>; // WarpTile (16x32x16)
|
||||
|
||||
// Vector sizes for loading
|
||||
static constexpr index_t VectorSizeA = 4;
|
||||
static constexpr index_t VectorSizeB = 4;
|
||||
|
||||
// Derive BlockShape from BlockGemmShape
|
||||
// Match V3's approach: use a simple 1×64 configuration for 1 warp
|
||||
// This ensures proper tile distribution for load_tile/store_tile operations
|
||||
// Back to 1 warp (64 threads) for proven norm reduction
|
||||
using BlockShape =
|
||||
Generic2dBlockShape<sequence<1, 64>, // BlockTile [1, 64] - simple layout for 1 warp
|
||||
Generic2dBlockShape<sequence<1, 64>, // BlockTile [1, 64] - layout for 1 warp
|
||||
sequence<1, 64>, // ThreadPerBlock [1, 64] = 64 threads (1 warp)
|
||||
sequence<1, 1>>; // Vector [1, 1] - no vectorization in BlockShape
|
||||
|
||||
@@ -86,17 +86,17 @@ struct MHCProblemV4
|
||||
|
||||
// Tile distribution for loading X (input matrix) from global memory
|
||||
// X is [Batch, nC] row-major, we load kM×kK tiles (16×16)
|
||||
// For a 16×16 tile with 64 threads (1 warp):
|
||||
// With 1 warp (64 threads):
|
||||
// M: 1 repeat × 1 warp × 16 threads × 1 vector = 16
|
||||
// K: 1 repeat × 1 warp × 4 threads × 4 vector = 16
|
||||
// Total threads: 1 warp × (16×4) = 64 threads ✓
|
||||
// Total threads: 1 warp × 64 threads = 64 threads ✓
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// H0 (M dimension): [repeat=1, warp=1, thread=16, vector=1] = 16
|
||||
// H1 (K dimension): [repeat=1, warp=1, thread=4, vector=4] = 16
|
||||
// P→RH: Warp layout = 1 warp in M × 1 warp in K
|
||||
// P→RH: Warp layout = 1 warp in M × 1 warp in K = 1 warp total
|
||||
// Thread layout = 16 threads in M × 4 threads in K = 64 threads/warp
|
||||
// Y→RH: Access order = M_repeat → M_vector → K_repeat → K_vector (vectorized)
|
||||
using XTileDistEncoding = tile_distribution_encoding<
|
||||
@@ -112,15 +112,27 @@ struct MHCProblemV4
|
||||
}
|
||||
|
||||
// Tile distribution for loading Phi (weight matrix) from global memory
|
||||
// Phi is [output_dim, nC] row-major, we load kN×kK tiles (16×16)
|
||||
// Phi is [output_dim, nC] row-major, we load kN×kK tiles (32×16)
|
||||
// With 1 warp (64 threads), use 2 repeats in N to cover 32 elements:
|
||||
// N: 2 repeat × 1 warp × 8 threads × 2 vector = 32
|
||||
// K: 1 repeat × 1 warp × 4 threads × 4 vector = 16
|
||||
// Total threads: 1 warp × 64 threads = 64 threads ✓
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// Same distribution as X for 16×16 tiles
|
||||
// H0 (N dimension): [repeat=2, warp=1, thread=8, vector=2] = 32
|
||||
// H1 (K dimension): [repeat=1, warp=1, thread=4, vector=4] = 16
|
||||
// P→RH: Warp layout = 1 warp in N × 1 warp in K = 1 warp total
|
||||
// Thread layout = 8 threads in N × 4 threads in K = 32 threads/warp... wait that's
|
||||
// only 32!
|
||||
// Need to recalculate: 8×4=32 threads, but we have 64 threads/warp
|
||||
// Better: N: 1 repeat × 1 warp × 16 threads × 2 vector = 32
|
||||
// K: 1 repeat × 1 warp × 4 threads × 4 vector = 16
|
||||
// Thread layout: 16×4 = 64 threads ✓
|
||||
using PhiTileDistEncoding = tile_distribution_encoding<
|
||||
sequence<>, // R: No replication
|
||||
tuple<sequence<1, 1, 16, 1>, // H0 (N): repeat=1, warp=1, thread=16, vector=1
|
||||
tuple<sequence<1, 1, 16, 2>, // H0 (N): repeat=1, warp=1, thread=16, vector=2
|
||||
sequence<1, 1, 4, 4>>, // H1 (K): repeat=1, warp=1, thread=4, vector=4
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
|
||||
|
||||
Reference in New Issue
Block a user