V5: reintroduce k-loop + adaptive k-tile size

This commit is contained in:
Damien Lejeune
2026-02-12 13:54:58 +00:00
parent 5fe7632393
commit 0d7a341d27

View File

@@ -42,6 +42,17 @@ struct MHCKernelV5
static constexpr index_t kBlockSize = Problem::kBlockSize;
// Adaptive K-tiles per block based on C dimension
CK_TILE_HOST_DEVICE static constexpr index_t GetKTilesPerBlock(index_t nC)
{
    // Choose how many kKTile-wide K-tiles a single block walks through,
    // scaled by the size of the C (split-K / reduction) dimension.
    // Element counts below assume kKTile == 64 — confirm against Problem.
    if(nC >= 4096)
        return 8; // large C: 8 tiles (512 elems) - maximize MFMA utilization
    if(nC >= 1024)
        return 4; // medium C: 4 tiles (256 elems) - balance overhead and work
    if(nC >= 256)
        return 2; // small C: 2 tiles (128 elems) - reduce overhead
    return 1;     // tiny C: 1 tile (64 elems) - minimize overhead
}
// Host-side accessor for the compile-time workgroup size (kBlockSize).
CK_TILE_HOST static constexpr auto BlockSize() { return kBlockSize; }
// Padding to avoid LDS bank conflicts
@@ -58,11 +69,14 @@ struct MHCKernelV5
}
// Grid configuration: 2D grid (B, C) for split-K
// Each block processes adaptive number of K-tiles (hierarchical split-K)
//
// Returns (grid_m, grid_k):
//   grid_m - number of blocks along the batch (M) dimension
//   grid_k - number of blocks along the split-K dimension, where each block
//            covers k_per_block = kKTile * GetKTilesPerBlock(nC) elements
// output_dim is accepted for signature compatibility but unused here.
CK_TILE_HOST static constexpr auto
GetGridSize(index_t batch, [[maybe_unused]] index_t output_dim, index_t nC)
{
    const index_t k_tiles_per_block = GetKTilesPerBlock(nC);
    const index_t k_per_block       = kKTile * k_tiles_per_block;
    // Ceil-divide so ragged edges in M and K still get a block.
    const index_t grid_m = (batch + kMTile - 1) / kMTile;
    const index_t grid_k = (nC + k_per_block - 1) / k_per_block;
    return make_tuple(grid_m, grid_k);
}
@@ -80,13 +94,18 @@ struct MHCKernelV5
[[maybe_unused]] float alpha_res = 1.0f,
[[maybe_unused]] float bias = 0.0f) const
{
// Determine adaptive K-tiles per block based on C dimension
const index_t k_tiles_per_block = GetKTilesPerBlock(nC);
const index_t k_per_block = kKTile * k_tiles_per_block;
// 2D block indexing
const index_t grid_m = (batch + kMTile - 1) / kMTile;
const index_t block_m = get_block_id() % grid_m;
const index_t block_k = get_block_id() / grid_m;
const index_t batch_start = block_m * kMTile;
const index_t k_start = block_k * kKTile;
const index_t k_start = block_k * k_per_block; // Start of this block's K-range
const index_t k_end = ck_tile::min(k_start + k_per_block, nC);
const index_t out_start = 0;
if(batch_start >= batch || k_start >= nC)
@@ -103,9 +122,6 @@ struct MHCKernelV5
auto result_tile = BlockGemm::MakeCBlockTile();
set_tile(result_tile, 0.0f);
// Determine actual K size for this block
const index_t k_size = ck_tile::min(kKTile, nC - k_start);
// Create tensor views for X and Phi
auto x_tensor_full = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(batch, nC), make_tuple(nC, 1), number<1>{}, number<1>{});
@@ -154,7 +170,7 @@ struct MHCKernelV5
auto phi_lds_window = make_tile_window(
phi_lds_tensor, make_tuple(number<kNTile>{}, number<kKTile>{}), {0, 0});
// Compute partial norms for this K-tile
// Compute partial norms for this block's K-range (all K-tiles)
const index_t thread_id = get_thread_id();
constexpr index_t threads_per_row = kBlockSize / kMTile;
const index_t row_id = thread_id / threads_per_row;
@@ -170,12 +186,13 @@ struct MHCKernelV5
if(global_m < batch)
{
const XDataType* row_ptr = p_x + global_m * nC + k_start;
const index_t k_range = k_end - k_start;
constexpr index_t kVecSize = 4;
for(index_t k = thread_in_row * kVecSize; k < k_size;
for(index_t k = thread_in_row * kVecSize; k < k_range;
k += threads_per_row * kVecSize)
{
if(k + kVecSize <= k_size)
if(k + kVecSize <= k_range)
{
using VecType = ext_vector_t<XDataType, kVecSize>;
VecType vec = *c_style_pointer_cast<const VecType*>(row_ptr + k);
@@ -189,7 +206,7 @@ struct MHCKernelV5
}
else
{
for(index_t i = 0; i < kVecSize && k + i < k_size; ++i)
for(index_t i = 0; i < kVecSize && k + i < k_range; ++i)
{
ComputeDataType val = type_convert<ComputeDataType>(row_ptr[k + i]);
partial_sum += val * val;
@@ -222,23 +239,34 @@ struct MHCKernelV5
}
}
// Load X tile for this K-slice
auto x_tile = make_static_distributed_tensor<XDataType>(x_load_tile_dist);
load_tile(x_tile, x_dram_window);
store_tile(x_lds_window, x_tile);
// Loop over K-tiles within this block's K-range (adaptive count)
for(index_t k_tile_idx = 0; k_tile_idx < k_tiles_per_block; ++k_tile_idx)
{
const index_t k_current = k_start + k_tile_idx * kKTile;
if(k_current >= k_end)
break;
// Load Phi tile for this K-slice
auto phi_tile = make_static_distributed_tensor<PhiDataType>(phi_load_tile_dist);
load_tile(phi_tile, phi_dram_window);
store_tile(phi_lds_window, phi_tile);
// Load X tile for this K-slice
auto x_tile = make_static_distributed_tensor<XDataType>(x_load_tile_dist);
load_tile(x_tile, x_dram_window);
store_tile(x_lds_window, x_tile);
block_sync_lds();
// Load Phi tile for this K-slice
auto phi_tile = make_static_distributed_tensor<PhiDataType>(phi_load_tile_dist);
load_tile(phi_tile, phi_dram_window);
store_tile(phi_lds_window, phi_tile);
// Perform GEMM for this K-slice: result_tile = x_lds * phi_lds^T
// Note: This is a partial result for just this K-tile
BlockGemm{}(result_tile, x_lds_window, phi_lds_window);
block_sync_lds();
block_sync_lds();
// Accumulate GEMM: result_tile += x_lds * phi_lds^T
BlockGemm{}(result_tile, x_lds_window, phi_lds_window);
// Move windows to next K-tile
move_tile_window(x_dram_window, {0, kKTile});
move_tile_window(phi_dram_window, {0, kKTile});
block_sync_lds();
}
// Store partial results to workspace buffer: p_workspace[block_k, batch, output_dim]
// Layout: [grid_k][batch][output_dim]