diff --git a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v2.hpp b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v2.hpp index 8c3ef489b3..e3a0c46d6f 100644 --- a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v2.hpp +++ b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v2.hpp @@ -23,7 +23,8 @@ template // Channels per stream (compile-time) + index_t C_ = 64, // Channels per stream (compile-time) + index_t KTile_ = 256> // K-tile size for shared memory (compile-time) struct ManifoldConstrainedHyperConnectionTiled { using Problem = ck_tile::remove_cvref_t; @@ -39,6 +40,7 @@ struct ManifoldConstrainedHyperConnectionTiled static constexpr index_t kC = C_; // Channels per stream (compile-time) static constexpr index_t kNC = kN * kC; // Input dimension (compile-time) static constexpr index_t kOutputDim = 2 * kN + kN * kN; // Output dimension (compile-time) + static constexpr index_t kKTile = KTile_; // K-tile size (compile-time) static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; @@ -49,8 +51,9 @@ struct ManifoldConstrainedHyperConnectionTiled CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - // Need shared memory for reduction - return kNC * sizeof(ComputeDataType); + // Shared memory is now bounded by kKTile instead of kNC + // This allows handling arbitrary C values + return kKTile * sizeof(ComputeDataType); } CK_TILE_DEVICE void operator()(const XDataType* p_x, // [B, nC] - input tensor @@ -83,68 +86,26 @@ struct ManifoldConstrainedHyperConnectionTiled constexpr index_t kBatchTile = 16; // Process 16 batches per tile const index_t num_batch_tiles = (B + kBatchTile - 1) / kBatchTile; - // With expansion-parallel strategy: + // Calculate number of K-tile iterations needed for large C values + // This allows us to handle arbitrary C by processing K in chunks + constexpr index_t num_ktile_iterations = (kNC + kKTile - 1) / kKTile; + + // With expansion-parallel strategy + K-tiling: // - Grid size = output_dim (one block per output column) // - Each block computes output[:, stream_id] for ALL batches // - GEMM becomes: x[B, nC] * phi[nC, 1] = output[B, 1] - // - This gives us M=B (e.g., 64), K=nC (e.g., 1024), N=1 + // - K dimension is tiled to fit in shared memory - // Step 1: Allocate LDS for x - need to load batches in tiles - // For BlockGemm, we need x[kBatchTile, nC] in LDS - // Process batches in chunks of kBatchTile (16 batches at a time) - __shared__ XDataType x_lds[kBatchTile * kNC]; // Allocate for 16 batches × nC elements + // Step 1: Allocate LDS for x - bounded by kKTile instead of kNC + // For BlockGemm, we need x[kBatchTile, kKTile] in LDS + __shared__ XDataType + x_lds[kBatchTile * kKTile]; // Allocate for 16 batches × kKTile elements - // Step 2: Create phi infrastructure in LDS (shared across all batch tiles) - // For this stream, we need phi[:, stream_id:stream_id+16] which is [nC, 16] + // Step 2: Create phi infrastructure in LDS - bounded by kKTile + // For this stream, we need phi[:, stream_id:stream_id+16] which is [kKTile, 16] // IMPORTANT: BlockGemm expects B matrix in K-major (column-major) layout! - // So phi_lds should be organized as [K_outer, N, K_inner] = [nC/16, 16, 16] - // with linear index: k_outer * (16 * 16) + n * 16 + k_inner - // This way, elements in the same column are contiguous (or nearly so with padding) constexpr index_t kKPack = 16; // Pack size for K dimension - __shared__ PhiDataType phi_lds[kNC * 16]; - - // Load with K-major layout: iterate over K_outer, N, K_inner - for(index_t i = tid; i < kNC * 16; i += get_block_size()) - { - // Decode linear index for K-major layout - index_t k_outer = i / (16 * kKPack); // 0-15 (256/16) - index_t remainder = i % (16 * kKPack); - index_t n_idx = remainder / kKPack; // 0-15 (N dimension) - renamed to avoid shadowing - index_t k_inner = remainder % kKPack; // 0-15 (K pack) - - index_t global_k = k_outer * kKPack + k_inner; // Actual K index (0-255) - index_t global_n = stream_id + n_idx; // Actual N index - - if(global_k < nC && global_n < output_dim) - { - // Load phi[global_k, global_n] from global memory - phi_lds[i] = p_phi[global_k * output_dim + global_n]; - } - else - { - phi_lds[i] = 0; // Pad - } - } - block_sync_lds(); - - // Create phi tensor view with K-major layout - // Layout is [K_outer, N, K_inner] = [(kNC+15)/16, 16, 16] with appropriate strides - // Use ceiling division to handle non-divisible-by-16 cases - constexpr index_t kKOuter = (kNC + kKPack - 1) / kKPack; - const auto phi_lds_tensor_3d = make_naive_tensor_view( - phi_lds, - make_tuple(number{}, number<16>{}, number{}), - make_tuple(number<16 * kKPack>{}, number{}, number<1>{}), - number{}, - number<1>{}); - - // Transform to 2D [K, N] by merging K_outer and K_inner - const auto phi_lds_tensor = transform_tensor_view( - phi_lds_tensor_3d, - make_tuple(make_pass_through_transform(number<16>{}), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + __shared__ PhiDataType phi_lds[kKTile * 16]; using BlockGemm = BlockGemmASmemBSmemCRegV1; @@ -156,121 +117,115 @@ struct ManifoldConstrainedHyperConnectionTiled const index_t batch_end = min(batch_start + kBatchTile, B); const index_t current_batch_count = batch_end - batch_start; - // Step 3a: Load x from global to LDS for this batch tile - // Load current_batch_count batches, each with nC elements - for(index_t i = tid; i < kBatchTile * kNC; i += get_block_size()) - { - index_t local_batch_idx = i / kNC; - index_t elem_idx = i % kNC; - index_t global_batch_idx = batch_start + local_batch_idx; - - if(local_batch_idx < current_batch_count && elem_idx < nC) - { - x_lds[i] = p_x[global_batch_idx * nC + elem_idx]; - } - else - { - x_lds[i] = 0; // Pad with zeros for out-of-bounds - } - } - block_sync_lds(); - - // Step 3b: Create LDS tensor view for x - const auto x_lds_tensor = make_naive_tensor_view( - x_lds, - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number<1>{}, - number<1>{}); - - // Step 3c: Create tile window for x in LDS - // BlockGemm expects [kM, kK] = [16, 16] - auto x_lds_window = make_tile_window( - x_lds_tensor, - make_tuple(number<16>{}, number<16>{}), // [M=16, K=16] to match BlockGemmShape - {0, 0}); // origin - - // Step 3d: Create phi tile window (reset for each batch tile) - auto phi_lds_window = make_tile_window( - phi_lds_tensor, - make_tuple(number<16>{}, number<16>{}), // [K=16, N=16] to match BlockGemmShape - {0, 0}); - - // Step 3e: Initialize result tile to zero for this batch tile + // Step 3a: Initialize result tile to zero for this batch tile auto result_tile = BlockGemm::MakeCBlockTile(); set_tile(result_tile, 0.0f); - // Step 3f: Iterate over K dimension: nC, BlockGemm processes K=16 at a time - // Use ceiling division to handle non-divisible-by-16 cases - constexpr index_t num_k_tiles = (kNC + 15) / 16; - for(index_t k_tile = 0; k_tile < num_k_tiles; k_tile++) + // Step 3b: Iterate over K-tiles (outer loop for large C values) + for(index_t ktile_idx = 0; ktile_idx < num_ktile_iterations; ktile_idx++) { - // Move windows to next K tile - if(k_tile > 0) + // Calculate K range for this tile + const index_t k_start = ktile_idx * kKTile; + const index_t k_end = min(k_start + kKTile, nC); + const index_t current_k_len = k_end - k_start; + + // Step 3b-i: Load x from global to LDS for this batch tile and K-tile + for(index_t i = tid; i < kBatchTile * kKTile; i += get_block_size()) { - move_tile_window(x_lds_window, {0, 16}); // Move K dimension - move_tile_window(phi_lds_window, {16, 0}); // Move K dimension + index_t local_batch_idx = i / kKTile; + index_t local_k_idx = i % kKTile; + index_t global_batch_idx = batch_start + local_batch_idx; + index_t global_k_idx = k_start + local_k_idx; + + if(local_batch_idx < current_batch_count && local_k_idx < current_k_len) + { + x_lds[i] = p_x[global_batch_idx * nC + global_k_idx]; + } + else + { + x_lds[i] = 0; // Pad with zeros for out-of-bounds + } } - // Accumulate: result_tile += x_lds_window * phi_lds_window - BlockGemm{}(result_tile, x_lds_window, phi_lds_window); - } - - // Step 3g: Compute norm ||x_l||_2 / sqrt(nC) for each batch in this tile - // We need this for potential normalization - // Allocate shared memory for norm computation - __shared__ ComputeDataType norm_shared[kBatchTile]; - - // Compute norm for each batch in this tile - for(index_t local_batch_idx = 0; local_batch_idx < current_batch_count; - local_batch_idx++) - { - ComputeDataType local_sum = 0.0f; - - // Each thread accumulates part of the sum of squares - for(index_t k = tid; k < nC; k += get_block_size()) + // Step 3b-ii: Load phi from global to LDS for this K-tile + // Load with K-major layout for optimal BlockGemm performance + // Layout: [K_outer, N, K_inner] where K_outer * K_inner = kKTile + for(index_t i = tid; i < kKTile * 16; i += get_block_size()) { - ComputeDataType val = - type_convert(x_lds[local_batch_idx * kNC + k]); - local_sum += val * val; - } + // Decode linear index for K-major layout + index_t k_outer_local = i / (16 * kKPack); + index_t remainder = i % (16 * kKPack); + index_t n_idx = remainder / kKPack; + index_t k_inner = remainder % kKPack; - // Simple reduction (can be optimized with block_reduce) - // For now, use atomic add to shared memory - if(tid == 0) - { - norm_shared[local_batch_idx] = 0; + index_t local_k = k_outer_local * kKPack + k_inner; + index_t global_k = k_start + local_k; + index_t global_n = stream_id + n_idx; + + if(local_k < current_k_len && global_n < output_dim) + { + phi_lds[i] = p_phi[global_k * output_dim + global_n]; + } + else + { + phi_lds[i] = 0; // Pad + } } block_sync_lds(); - // Accumulate (simplified - should use proper reduction) - if(local_sum > 0) + // Step 3b-iii: Create LDS tensor views for this K-tile + const auto x_lds_tensor = make_naive_tensor_view( + x_lds, + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number<1>{}, + number<1>{}); + + // Create phi tensor view with K-major layout + constexpr index_t kKOuter_tile = (kKTile + kKPack - 1) / kKPack; + const auto phi_lds_tensor_3d = make_naive_tensor_view( + phi_lds, + make_tuple(number{}, number<16>{}, number{}), + make_tuple(number<16 * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + const auto phi_lds_tensor = transform_tensor_view( + phi_lds_tensor_3d, + make_tuple( + make_pass_through_transform(number<16>{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // Step 3b-iv: Create tile windows + auto x_lds_window = + make_tile_window(x_lds_tensor, make_tuple(number<16>{}, number<16>{}), {0, 0}); + + auto phi_lds_window = make_tile_window( + phi_lds_tensor, make_tuple(number<16>{}, number<16>{}), {0, 0}); + + // Step 3b-v: Iterate over 16x16 tiles within this K-tile + constexpr index_t num_inner_k_tiles = (kKTile + 15) / 16; + for(index_t inner_k_tile = 0; inner_k_tile < num_inner_k_tiles; inner_k_tile++) { - atomicAdd(&norm_shared[local_batch_idx], local_sum); + if(inner_k_tile > 0) + { + move_tile_window(x_lds_window, {0, 16}); + move_tile_window(phi_lds_window, {16, 0}); + } + + // Accumulate: result_tile += x_lds_window * phi_lds_window + BlockGemm{}(result_tile, x_lds_window, phi_lds_window); } + block_sync_lds(); + } // End K-tile loop - // Compute final norm - if(tid == 0) - { - norm_shared[local_batch_idx] = sqrt(norm_shared[local_batch_idx]) / - sqrt(type_convert(nC)); - } - block_sync_lds(); - } - - // Step 3h: Apply elementwise operations to result_tile using tile_elementwise_inout - // Determine which section this stream belongs to and get alpha - float alpha = (stream_id < kN) ? alpha_pre - : (stream_id < 2 * kN) ? alpha_post - : alpha_res; - - // Apply scaling: result = (alpha / r) * result + bias - tile_elementwise_inout([&](auto& val) { val = (alpha / r) * val + bias; }, result_tile); - - // Step 3i: Store result_tile to output - // Use sweep_tile_span for manual writes since runtime batch_start offset - // doesn't work well with make_tile_window + // Step 3h & 3i: Apply elementwise operations and store result_tile to output + // We need to apply different alpha values based on which output column each element + // belongs to Since result_tile contains columns [stream_id, stream_id+16), we apply + // alpha during store constexpr auto result_spans = decltype(result_tile)::get_distributed_spans(); sweep_tile_span(result_spans[number<0>{}], [&](auto idx0) { @@ -286,9 +241,16 @@ struct ManifoldConstrainedHyperConnectionTiled if(global_batch < B && global_col < output_dim) { + // Determine alpha based on the actual output column + float alpha = (global_col < kN) ? alpha_pre + : (global_col < 2 * kN) ? alpha_post + : alpha_res; + + // Apply scaling and bias, then store: result = (alpha / r) * result + bias constexpr auto i_j_idx = make_tuple(idx0, idx1); const index_t global_idx = global_batch * output_dim + global_col; - p_output[global_idx] = type_convert(result_tile[i_j_idx]); + p_output[global_idx] = + type_convert((alpha / r) * result_tile[i_j_idx] + bias); } }); }); diff --git a/test/ck_tile/mhc/test_mhc.cpp b/test/ck_tile/mhc/test_mhc.cpp index 869d065202..6c622143f6 100644 --- a/test/ck_tile/mhc/test_mhc.cpp +++ b/test/ck_tile/mhc/test_mhc.cpp @@ -55,3 +55,15 @@ TYPED_TEST(TestCkTileMHC, TestBatchSize16N2C128) { this->template RunBatchSizeTe TYPED_TEST(TestCkTileMHC, TestBatchSize32N3C85) { this->template RunBatchSizeTest<32, 3, 85>(); } TYPED_TEST(TestCkTileMHC, TestBatchSize8N8C32) { this->template RunBatchSizeTest<8, 8, 32>(); } + +// Test with large C values that require K-tiling (nC > 256) +TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C512) { this->template RunBatchSizeTest<2, 4, 512>(); } + +TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C1024) { this->template RunBatchSizeTest<2, 4, 1024>(); } + +TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C4096) { this->template RunBatchSizeTest<2, 4, 4096>(); } + +TYPED_TEST(TestCkTileMHC, TestBatchSize16N4C4096) +{ + this->template RunBatchSizeTest<16, 4, 4096>(); +} diff --git a/test/ck_tile/mhc/test_mhc_impl.hpp b/test/ck_tile/mhc/test_mhc_impl.hpp index afcc189887..60ceae983b 100644 --- a/test/ck_tile/mhc/test_mhc_impl.hpp +++ b/test/ck_tile/mhc/test_mhc_impl.hpp @@ -319,7 +319,8 @@ class TestCkTileMHC : public ::testing::Test const ck_tile::index_t kGridSize = (output_dim + 15) / 16; constexpr ck_tile::index_t kBlockPerCu = 1; - const float r = 1.0f, alpha_pre = 1.0f, alpha_post = 1.0f, alpha_res = 1.0f, bias = 0.0f; + // const float r = 1.0f, alpha_pre = 1.0f, alpha_post = 1.0f, alpha_res = 1.0f, bias = 0.0f; + const float r = 2.0f, alpha_pre = 1.5f, alpha_post = 2.5f, alpha_res = 3.5f, bias = 1.5f; // Launch kernel (B, n, C are now template parameters, not runtime) ck_tile::launch_kernel(