mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
Tile on the C dimensions to support large C
This commit is contained in:
@@ -23,7 +23,8 @@ template <typename Problem_,
|
||||
typename Policy_ = MHCDefaultPolicy,
|
||||
index_t B_ = 16, // Batch size (compile-time)
|
||||
index_t N_ = 4, // Expansion factor (compile-time)
|
||||
index_t C_ = 64> // Channels per stream (compile-time)
|
||||
index_t C_ = 64, // Channels per stream (compile-time)
|
||||
index_t KTile_ = 256> // K-tile size for shared memory (compile-time)
|
||||
struct ManifoldConstrainedHyperConnectionTiled
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
@@ -39,6 +40,7 @@ struct ManifoldConstrainedHyperConnectionTiled
|
||||
static constexpr index_t kC = C_; // Channels per stream (compile-time)
|
||||
static constexpr index_t kNC = kN * kC; // Input dimension (compile-time)
|
||||
static constexpr index_t kOutputDim = 2 * kN + kN * kN; // Output dimension (compile-time)
|
||||
static constexpr index_t kKTile = KTile_; // K-tile size (compile-time)
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
@@ -49,8 +51,9 @@ struct ManifoldConstrainedHyperConnectionTiled
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
// Need shared memory for reduction
|
||||
return kNC * sizeof(ComputeDataType);
|
||||
// Shared memory is now bounded by kKTile instead of kNC
|
||||
// This allows handling arbitrary C values
|
||||
return kKTile * sizeof(ComputeDataType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x, // [B, nC] - input tensor
|
||||
@@ -83,68 +86,26 @@ struct ManifoldConstrainedHyperConnectionTiled
|
||||
constexpr index_t kBatchTile = 16; // Process 16 batches per tile
|
||||
const index_t num_batch_tiles = (B + kBatchTile - 1) / kBatchTile;
|
||||
|
||||
// With expansion-parallel strategy:
|
||||
// Calculate number of K-tile iterations needed for large C values
|
||||
// This allows us to handle arbitrary C by processing K in chunks
|
||||
constexpr index_t num_ktile_iterations = (kNC + kKTile - 1) / kKTile;
|
||||
|
||||
// With expansion-parallel strategy + K-tiling:
|
||||
// - Grid size = output_dim (one block per output column)
|
||||
// - Each block computes output[:, stream_id] for ALL batches
|
||||
// - GEMM becomes: x[B, nC] * phi[nC, 1] = output[B, 1]
|
||||
// - This gives us M=B (e.g., 64), K=nC (e.g., 1024), N=1
|
||||
// - K dimension is tiled to fit in shared memory
|
||||
|
||||
// Step 1: Allocate LDS for x - need to load batches in tiles
|
||||
// For BlockGemm, we need x[kBatchTile, nC] in LDS
|
||||
// Process batches in chunks of kBatchTile (16 batches at a time)
|
||||
__shared__ XDataType x_lds[kBatchTile * kNC]; // Allocate for 16 batches × nC elements
|
||||
// Step 1: Allocate LDS for x - bounded by kKTile instead of kNC
|
||||
// For BlockGemm, we need x[kBatchTile, kKTile] in LDS
|
||||
__shared__ XDataType
|
||||
x_lds[kBatchTile * kKTile]; // Allocate for 16 batches × kKTile elements
|
||||
|
||||
// Step 2: Create phi infrastructure in LDS (shared across all batch tiles)
|
||||
// For this stream, we need phi[:, stream_id:stream_id+16] which is [nC, 16]
|
||||
// Step 2: Create phi infrastructure in LDS - bounded by kKTile
|
||||
// For this stream, we need phi[:, stream_id:stream_id+16] which is [kKTile, 16]
|
||||
// IMPORTANT: BlockGemm expects B matrix in K-major (column-major) layout!
|
||||
// So phi_lds should be organized as [K_outer, N, K_inner] = [nC/16, 16, 16]
|
||||
// with linear index: k_outer * (16 * 16) + n * 16 + k_inner
|
||||
// This way, elements in the same column are contiguous (or nearly so with padding)
|
||||
constexpr index_t kKPack = 16; // Pack size for K dimension
|
||||
__shared__ PhiDataType phi_lds[kNC * 16];
|
||||
|
||||
// Load with K-major layout: iterate over K_outer, N, K_inner
|
||||
for(index_t i = tid; i < kNC * 16; i += get_block_size())
|
||||
{
|
||||
// Decode linear index for K-major layout
|
||||
index_t k_outer = i / (16 * kKPack); // 0-15 (256/16)
|
||||
index_t remainder = i % (16 * kKPack);
|
||||
index_t n_idx = remainder / kKPack; // 0-15 (N dimension) - renamed to avoid shadowing
|
||||
index_t k_inner = remainder % kKPack; // 0-15 (K pack)
|
||||
|
||||
index_t global_k = k_outer * kKPack + k_inner; // Actual K index (0-255)
|
||||
index_t global_n = stream_id + n_idx; // Actual N index
|
||||
|
||||
if(global_k < nC && global_n < output_dim)
|
||||
{
|
||||
// Load phi[global_k, global_n] from global memory
|
||||
phi_lds[i] = p_phi[global_k * output_dim + global_n];
|
||||
}
|
||||
else
|
||||
{
|
||||
phi_lds[i] = 0; // Pad
|
||||
}
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Create phi tensor view with K-major layout
|
||||
// Layout is [K_outer, N, K_inner] = [(kNC+15)/16, 16, 16] with appropriate strides
|
||||
// Use ceiling division to handle non-divisible-by-16 cases
|
||||
constexpr index_t kKOuter = (kNC + kKPack - 1) / kKPack;
|
||||
const auto phi_lds_tensor_3d = make_naive_tensor_view<address_space_enum::lds>(
|
||||
phi_lds,
|
||||
make_tuple(number<kKOuter>{}, number<16>{}, number<kKPack>{}),
|
||||
make_tuple(number<16 * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
// Transform to 2D [K, N] by merging K_outer and K_inner
|
||||
const auto phi_lds_tensor = transform_tensor_view(
|
||||
phi_lds_tensor_3d,
|
||||
make_tuple(make_pass_through_transform(number<16>{}),
|
||||
make_merge_transform(make_tuple(number<kKOuter>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
__shared__ PhiDataType phi_lds[kKTile * 16];
|
||||
|
||||
using BlockGemm = BlockGemmASmemBSmemCRegV1<Problem, Policy>;
|
||||
|
||||
@@ -156,121 +117,115 @@ struct ManifoldConstrainedHyperConnectionTiled
|
||||
const index_t batch_end = min(batch_start + kBatchTile, B);
|
||||
const index_t current_batch_count = batch_end - batch_start;
|
||||
|
||||
// Step 3a: Load x from global to LDS for this batch tile
|
||||
// Load current_batch_count batches, each with nC elements
|
||||
for(index_t i = tid; i < kBatchTile * kNC; i += get_block_size())
|
||||
{
|
||||
index_t local_batch_idx = i / kNC;
|
||||
index_t elem_idx = i % kNC;
|
||||
index_t global_batch_idx = batch_start + local_batch_idx;
|
||||
|
||||
if(local_batch_idx < current_batch_count && elem_idx < nC)
|
||||
{
|
||||
x_lds[i] = p_x[global_batch_idx * nC + elem_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
x_lds[i] = 0; // Pad with zeros for out-of-bounds
|
||||
}
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Step 3b: Create LDS tensor view for x
|
||||
const auto x_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
|
||||
x_lds,
|
||||
make_tuple(number<kBatchTile>{}, number<kNC>{}),
|
||||
make_tuple(number<kNC>{}, number<1>{}),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
// Step 3c: Create tile window for x in LDS
|
||||
// BlockGemm expects [kM, kK] = [16, 16]
|
||||
auto x_lds_window = make_tile_window(
|
||||
x_lds_tensor,
|
||||
make_tuple(number<16>{}, number<16>{}), // [M=16, K=16] to match BlockGemmShape
|
||||
{0, 0}); // origin
|
||||
|
||||
// Step 3d: Create phi tile window (reset for each batch tile)
|
||||
auto phi_lds_window = make_tile_window(
|
||||
phi_lds_tensor,
|
||||
make_tuple(number<16>{}, number<16>{}), // [K=16, N=16] to match BlockGemmShape
|
||||
{0, 0});
|
||||
|
||||
// Step 3e: Initialize result tile to zero for this batch tile
|
||||
// Step 3a: Initialize result tile to zero for this batch tile
|
||||
auto result_tile = BlockGemm::MakeCBlockTile();
|
||||
set_tile(result_tile, 0.0f);
|
||||
|
||||
// Step 3f: Iterate over K dimension: nC, BlockGemm processes K=16 at a time
|
||||
// Use ceiling division to handle non-divisible-by-16 cases
|
||||
constexpr index_t num_k_tiles = (kNC + 15) / 16;
|
||||
for(index_t k_tile = 0; k_tile < num_k_tiles; k_tile++)
|
||||
// Step 3b: Iterate over K-tiles (outer loop for large C values)
|
||||
for(index_t ktile_idx = 0; ktile_idx < num_ktile_iterations; ktile_idx++)
|
||||
{
|
||||
// Move windows to next K tile
|
||||
if(k_tile > 0)
|
||||
// Calculate K range for this tile
|
||||
const index_t k_start = ktile_idx * kKTile;
|
||||
const index_t k_end = min(k_start + kKTile, nC);
|
||||
const index_t current_k_len = k_end - k_start;
|
||||
|
||||
// Step 3b-i: Load x from global to LDS for this batch tile and K-tile
|
||||
for(index_t i = tid; i < kBatchTile * kKTile; i += get_block_size())
|
||||
{
|
||||
move_tile_window(x_lds_window, {0, 16}); // Move K dimension
|
||||
move_tile_window(phi_lds_window, {16, 0}); // Move K dimension
|
||||
index_t local_batch_idx = i / kKTile;
|
||||
index_t local_k_idx = i % kKTile;
|
||||
index_t global_batch_idx = batch_start + local_batch_idx;
|
||||
index_t global_k_idx = k_start + local_k_idx;
|
||||
|
||||
if(local_batch_idx < current_batch_count && local_k_idx < current_k_len)
|
||||
{
|
||||
x_lds[i] = p_x[global_batch_idx * nC + global_k_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
x_lds[i] = 0; // Pad with zeros for out-of-bounds
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate: result_tile += x_lds_window * phi_lds_window
|
||||
BlockGemm{}(result_tile, x_lds_window, phi_lds_window);
|
||||
}
|
||||
|
||||
// Step 3g: Compute norm ||x_l||_2 / sqrt(nC) for each batch in this tile
|
||||
// We need this for potential normalization
|
||||
// Allocate shared memory for norm computation
|
||||
__shared__ ComputeDataType norm_shared[kBatchTile];
|
||||
|
||||
// Compute norm for each batch in this tile
|
||||
for(index_t local_batch_idx = 0; local_batch_idx < current_batch_count;
|
||||
local_batch_idx++)
|
||||
{
|
||||
ComputeDataType local_sum = 0.0f;
|
||||
|
||||
// Each thread accumulates part of the sum of squares
|
||||
for(index_t k = tid; k < nC; k += get_block_size())
|
||||
// Step 3b-ii: Load phi from global to LDS for this K-tile
|
||||
// Load with K-major layout for optimal BlockGemm performance
|
||||
// Layout: [K_outer, N, K_inner] where K_outer * K_inner = kKTile
|
||||
for(index_t i = tid; i < kKTile * 16; i += get_block_size())
|
||||
{
|
||||
ComputeDataType val =
|
||||
type_convert<ComputeDataType>(x_lds[local_batch_idx * kNC + k]);
|
||||
local_sum += val * val;
|
||||
}
|
||||
// Decode linear index for K-major layout
|
||||
index_t k_outer_local = i / (16 * kKPack);
|
||||
index_t remainder = i % (16 * kKPack);
|
||||
index_t n_idx = remainder / kKPack;
|
||||
index_t k_inner = remainder % kKPack;
|
||||
|
||||
// Simple reduction (can be optimized with block_reduce)
|
||||
// For now, use atomic add to shared memory
|
||||
if(tid == 0)
|
||||
{
|
||||
norm_shared[local_batch_idx] = 0;
|
||||
index_t local_k = k_outer_local * kKPack + k_inner;
|
||||
index_t global_k = k_start + local_k;
|
||||
index_t global_n = stream_id + n_idx;
|
||||
|
||||
if(local_k < current_k_len && global_n < output_dim)
|
||||
{
|
||||
phi_lds[i] = p_phi[global_k * output_dim + global_n];
|
||||
}
|
||||
else
|
||||
{
|
||||
phi_lds[i] = 0; // Pad
|
||||
}
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Accumulate (simplified - should use proper reduction)
|
||||
if(local_sum > 0)
|
||||
// Step 3b-iii: Create LDS tensor views for this K-tile
|
||||
const auto x_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
|
||||
x_lds,
|
||||
make_tuple(number<kBatchTile>{}, number<kKTile>{}),
|
||||
make_tuple(number<kKTile>{}, number<1>{}),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
// Create phi tensor view with K-major layout
|
||||
constexpr index_t kKOuter_tile = (kKTile + kKPack - 1) / kKPack;
|
||||
const auto phi_lds_tensor_3d = make_naive_tensor_view<address_space_enum::lds>(
|
||||
phi_lds,
|
||||
make_tuple(number<kKOuter_tile>{}, number<16>{}, number<kKPack>{}),
|
||||
make_tuple(number<16 * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
const auto phi_lds_tensor = transform_tensor_view(
|
||||
phi_lds_tensor_3d,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<16>{}),
|
||||
make_merge_transform(make_tuple(number<kKOuter_tile>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Step 3b-iv: Create tile windows
|
||||
auto x_lds_window =
|
||||
make_tile_window(x_lds_tensor, make_tuple(number<16>{}, number<16>{}), {0, 0});
|
||||
|
||||
auto phi_lds_window = make_tile_window(
|
||||
phi_lds_tensor, make_tuple(number<16>{}, number<16>{}), {0, 0});
|
||||
|
||||
// Step 3b-v: Iterate over 16x16 tiles within this K-tile
|
||||
constexpr index_t num_inner_k_tiles = (kKTile + 15) / 16;
|
||||
for(index_t inner_k_tile = 0; inner_k_tile < num_inner_k_tiles; inner_k_tile++)
|
||||
{
|
||||
atomicAdd(&norm_shared[local_batch_idx], local_sum);
|
||||
if(inner_k_tile > 0)
|
||||
{
|
||||
move_tile_window(x_lds_window, {0, 16});
|
||||
move_tile_window(phi_lds_window, {16, 0});
|
||||
}
|
||||
|
||||
// Accumulate: result_tile += x_lds_window * phi_lds_window
|
||||
BlockGemm{}(result_tile, x_lds_window, phi_lds_window);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
} // End K-tile loop
|
||||
|
||||
// Compute final norm
|
||||
if(tid == 0)
|
||||
{
|
||||
norm_shared[local_batch_idx] = sqrt(norm_shared[local_batch_idx]) /
|
||||
sqrt(type_convert<ComputeDataType>(nC));
|
||||
}
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// Step 3h: Apply elementwise operations to result_tile using tile_elementwise_inout
|
||||
// Determine which section this stream belongs to and get alpha
|
||||
float alpha = (stream_id < kN) ? alpha_pre
|
||||
: (stream_id < 2 * kN) ? alpha_post
|
||||
: alpha_res;
|
||||
|
||||
// Apply scaling: result = (alpha / r) * result + bias
|
||||
tile_elementwise_inout([&](auto& val) { val = (alpha / r) * val + bias; }, result_tile);
|
||||
|
||||
// Step 3i: Store result_tile to output
|
||||
// Use sweep_tile_span for manual writes since runtime batch_start offset
|
||||
// doesn't work well with make_tile_window
|
||||
// Step 3h & 3i: Apply elementwise operations and store result_tile to output
|
||||
// We need to apply different alpha values based on which output column each element
|
||||
// belongs to Since result_tile contains columns [stream_id, stream_id+16), we apply
|
||||
// alpha during store
|
||||
constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(result_spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -286,9 +241,16 @@ struct ManifoldConstrainedHyperConnectionTiled
|
||||
|
||||
if(global_batch < B && global_col < output_dim)
|
||||
{
|
||||
// Determine alpha based on the actual output column
|
||||
float alpha = (global_col < kN) ? alpha_pre
|
||||
: (global_col < 2 * kN) ? alpha_post
|
||||
: alpha_res;
|
||||
|
||||
// Apply scaling and bias, then store: result = (alpha / r) * result + bias
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const index_t global_idx = global_batch * output_dim + global_col;
|
||||
p_output[global_idx] = type_convert<YDataType>(result_tile[i_j_idx]);
|
||||
p_output[global_idx] =
|
||||
type_convert<YDataType>((alpha / r) * result_tile[i_j_idx] + bias);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user