Remove hard-coded LDS size

This commit is contained in:
Damien Lejeune
2026-01-29 05:24:19 -05:00
parent b83c07748c
commit c83b1c482b
3 changed files with 90 additions and 84 deletions

View File

@@ -21,7 +21,9 @@ namespace ck_tile {
template <typename Problem_,
typename Policy_ = MHCDefaultPolicy,
index_t N_ = 4> // Template parameter for expansion factor (compile-time)
index_t B_ = 16, // Batch size (compile-time)
index_t N_ = 4, // Expansion factor (compile-time)
index_t C_ = 64> // Channels per stream (compile-time)
struct ManifoldConstrainedHyperConnectionTiled
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
@@ -32,8 +34,11 @@ struct ManifoldConstrainedHyperConnectionTiled
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using PhiDataType = ck_tile::remove_cvref_t<typename Problem::PhiDataType>;
static constexpr index_t kB = B_; // Batch size (compile-time)
static constexpr index_t kN = N_; // Expansion factor (compile-time)
static constexpr index_t kOutputDim = 2 * kN + kN * kN; // 2n + n² (compile-time)
static constexpr index_t kC = C_; // Channels per stream (compile-time)
static constexpr index_t kNC = kN * kC; // Input dimension (compile-time)
static constexpr index_t kOutputDim = 2 * kN + kN * kN; // Output dimension (compile-time)
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
@@ -45,22 +50,21 @@ struct ManifoldConstrainedHyperConnectionTiled
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
// Need shared memory for reduction
return 256 * sizeof(ComputeDataType);
return kNC * sizeof(ComputeDataType);
}
CK_TILE_DEVICE void operator()(const XDataType* p_x, // [B, nC] - input tensor
const PhiDataType* p_phi, // [nC, 2n+n²] - packed weight matrices
YDataType* p_output, // [B, 2n+n²] - output tensor
int B, // batch size (now used!)
int C, // output layer dimension (runtime, potentially large)
float r = 1.0f, // scaling factor
float alpha_pre = 1.0f, // scaling for H^{pre}
float alpha_post = 1.0f, // scaling for H^{post}
float alpha_res = 1.0f, // scaling for H^{res}
float bias = 0.0f) const // bias term
float r = 1.0f, // scaling factor
float alpha_pre = 1.0f, // scaling for H^{pre}
float alpha_post = 1.0f, // scaling for H^{post}
float alpha_res = 1.0f, // scaling for H^{res}
float bias = 0.0f) const // bias term
{
const index_t nC = kN * C; // Use compile-time kN
constexpr index_t output_dim = kOutputDim; // Now compile-time!
constexpr index_t nC = kNC; // Compile-time!
constexpr index_t output_dim = kOutputDim; // Compile-time!
constexpr index_t B = kB; // Compile-time!
// NEW PARALLELIZATION STRATEGY:
// Each block processes 16 output columns starting at stream_id * 16
@@ -88,28 +92,28 @@ struct ManifoldConstrainedHyperConnectionTiled
// Step 1: Allocate LDS for x - need to load batches in tiles
// For BlockGemm, we need x[kBatchTile, nC] in LDS
// Process batches in chunks of kBatchTile (16 batches at a time)
__shared__ XDataType x_lds[kBatchTile * 256]; // Allocate for 16 batches × 256 elements
__shared__ XDataType x_lds[kBatchTile * kNC]; // Allocate for 16 batches × nC elements
// Step 2: Create phi infrastructure in LDS (shared across all batch tiles)
// For this stream, we need phi[:, stream_id:stream_id+16] which is [nC, 16]
// IMPORTANT: BlockGemm expects B matrix in K-major (column-major) layout!
// So phi_lds should be organized as [K_outer, N, K_inner] = [256/16, 16, 16]
// So phi_lds should be organized as [K_outer, N, K_inner] = [nC/16, 16, 16]
// with linear index: k_outer * (16 * 16) + n * 16 + k_inner
// This way, elements in the same column are contiguous (or nearly so with padding)
constexpr index_t kKPack = 16; // Pack size for K dimension
__shared__ PhiDataType phi_lds[256 * 16];
__shared__ PhiDataType phi_lds[kNC * 16];
// Load with K-major layout: iterate over K_outer, N, K_inner
for(index_t i = tid; i < 256 * 16; i += get_block_size())
for(index_t i = tid; i < kNC * 16; i += get_block_size())
{
// Decode linear index for K-major layout
index_t k_outer = i / (16 * kKPack); // 0-15 (256/16)
index_t remainder = i % (16 * kKPack);
index_t n = remainder / kKPack; // 0-15 (N dimension)
index_t k_inner = remainder % kKPack; // 0-15 (K pack)
index_t n_idx = remainder / kKPack; // 0-15 (N dimension) - renamed to avoid shadowing
index_t k_inner = remainder % kKPack; // 0-15 (K pack)
index_t global_k = k_outer * kKPack + k_inner; // Actual K index (0-255)
index_t global_n = stream_id + n; // Actual N index
index_t global_n = stream_id + n_idx; // Actual N index
if(global_k < nC && global_n < output_dim)
{
@@ -124,8 +128,9 @@ struct ManifoldConstrainedHyperConnectionTiled
block_sync_lds();
// Create phi tensor view with K-major layout
// Layout is [K_outer, N, K_inner] = [16, 16, 16] with appropriate strides
constexpr index_t kKOuter = 256 / kKPack; // 16
// Layout is [K_outer, N, K_inner] = [(kNC+15)/16, 16, 16] with appropriate strides
// Use ceiling division to handle non-divisible-by-16 cases
constexpr index_t kKOuter = (kNC + kKPack - 1) / kKPack;
const auto phi_lds_tensor_3d = make_naive_tensor_view<address_space_enum::lds>(
phi_lds,
make_tuple(number<kKOuter>{}, number<16>{}, number<kKPack>{}),
@@ -152,11 +157,11 @@ struct ManifoldConstrainedHyperConnectionTiled
const index_t current_batch_count = batch_end - batch_start;
// Step 3a: Load x from global to LDS for this batch tile
// Load current_batch_count batches, each with nC elements (but limit to 256 per batch)
for(index_t i = tid; i < kBatchTile * 256; i += get_block_size())
// Load current_batch_count batches, each with nC elements
for(index_t i = tid; i < kBatchTile * kNC; i += get_block_size())
{
index_t local_batch_idx = i / 256;
index_t elem_idx = i % 256;
index_t local_batch_idx = i / kNC;
index_t elem_idx = i % kNC;
index_t global_batch_idx = batch_start + local_batch_idx;
if(local_batch_idx < current_batch_count && elem_idx < nC)
@@ -173,8 +178,8 @@ struct ManifoldConstrainedHyperConnectionTiled
// Step 3b: Create LDS tensor view for x
const auto x_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
x_lds,
make_tuple(number<kBatchTile>{}, number<256>{}),
make_tuple(number<256>{}, number<1>{}),
make_tuple(number<kBatchTile>{}, number<kNC>{}),
make_tuple(number<kNC>{}, number<1>{}),
number<1>{},
number<1>{});
@@ -195,8 +200,9 @@ struct ManifoldConstrainedHyperConnectionTiled
auto result_tile = BlockGemm::MakeCBlockTile();
set_tile(result_tile, 0.0f);
// Step 3f: Iterate over K dimension: nC=256, BlockGemm processes K=16 at a time
constexpr index_t num_k_tiles = 256 / 16; // 16 iterations for K=256
// Step 3f: Iterate over K dimension: nC, BlockGemm processes K=16 at a time
// Use ceiling division to handle non-divisible-by-16 cases
constexpr index_t num_k_tiles = (kNC + 15) / 16;
for(index_t k_tile = 0; k_tile < num_k_tiles; k_tile++)
{
// Move windows to next K tile
@@ -222,10 +228,10 @@ struct ManifoldConstrainedHyperConnectionTiled
ComputeDataType local_sum = 0.0f;
// Each thread accumulates part of the sum of squares
for(index_t k = tid; k < nC && k < 256; k += get_block_size())
for(index_t k = tid; k < nC; k += get_block_size())
{
ComputeDataType val =
type_convert<ComputeDataType>(x_lds[local_batch_idx * 256 + k]);
type_convert<ComputeDataType>(x_lds[local_batch_idx * kNC + k]);
local_sum += val * val;
}
@@ -263,26 +269,21 @@ struct ManifoldConstrainedHyperConnectionTiled
tile_elementwise_inout([&](auto& val) { val = (alpha / r) * val + bias; }, result_tile);
// Step 3i: Store result_tile to output
// We need to manually write the result since the batch offset might not align with tile
// boundaries Get the distributed spans for iteration
// Use sweep_tile_span for manual writes since runtime batch_start offset
// doesn't work well with make_tile_window
constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
// Iterate over the tile and write each element to global memory
sweep_tile_span(result_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(result_spans[number<1>{}], [&](auto idx1) {
// Get actual tensor indices from distributed indices
const auto tile_idx = get_x_indices_from_distributed_indices(
result_tile.get_tile_distribution(), make_tuple(idx0, idx1));
// Extract integer indices
const index_t i_idx = tile_idx.at(number<0>{});
const index_t j_idx = tile_idx.at(number<1>{});
// Calculate global batch and column indices
const index_t global_batch = batch_start + i_idx;
const index_t global_col = stream_id + j_idx;
// Only write if within bounds
if(global_batch < B && global_col < output_dim)
{
constexpr auto i_j_idx = make_tuple(idx0, idx1);