diff --git a/example/ck_tile/42_mhc/mhc_v3.cpp b/example/ck_tile/42_mhc/mhc_v3.cpp index e71df68e0d..b5117c8d8e 100644 --- a/example/ck_tile/42_mhc/mhc_v3.cpp +++ b/example/ck_tile/42_mhc/mhc_v3.cpp @@ -48,9 +48,9 @@ int main() d_phi_mem.ToDevice(h_phi.data()); d_output_mem.ToDevice(h_output.data()); - // Define block shape - using BlockShape = ck_tile::Generic2dBlockShape, - ck_tile::sequence<1, 256>, + // Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads) + using BlockShape = ck_tile::Generic2dBlockShape, + ck_tile::sequence<1, 128>, ck_tile::sequence<1, 1>>; using Problem = ck_tile::MHCProblem; diff --git a/example/ck_tile/42_mhc/mhc_v3_single_block_test.cpp b/example/ck_tile/42_mhc/mhc_v3_single_block_test.cpp index 920869d1ff..a553415534 100644 --- a/example/ck_tile/42_mhc/mhc_v3_single_block_test.cpp +++ b/example/ck_tile/42_mhc/mhc_v3_single_block_test.cpp @@ -49,9 +49,9 @@ int main() d_phi_mem.ToDevice(h_phi.data()); d_output_mem.ToDevice(h_output.data()); - // Define block shape - using BlockShape = ck_tile::Generic2dBlockShape, - ck_tile::sequence<1, 256>, + // Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads) + using BlockShape = ck_tile::Generic2dBlockShape, + ck_tile::sequence<1, 128>, ck_tile::sequence<1, 1>>; using Problem = ck_tile::MHCProblem; diff --git a/example/ck_tile/42_mhc/mhc_v3_two_block_test.cpp b/example/ck_tile/42_mhc/mhc_v3_two_block_test.cpp index b15c22c0ce..5ef7dd8526 100644 --- a/example/ck_tile/42_mhc/mhc_v3_two_block_test.cpp +++ b/example/ck_tile/42_mhc/mhc_v3_two_block_test.cpp @@ -49,9 +49,9 @@ int main() d_phi_mem.ToDevice(h_phi.data()); d_output_mem.ToDevice(h_output.data()); - // Define block shape - using BlockShape = ck_tile::Generic2dBlockShape, - ck_tile::sequence<1, 256>, + // Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads) + using BlockShape = ck_tile::Generic2dBlockShape, + ck_tile::sequence<1, 128>, ck_tile::sequence<1, 1>>; using Problem = ck_tile::MHCProblem; diff --git a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp index 96b657a6a8..795414dc3e 100644 --- a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp +++ b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp @@ -6,7 +6,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" // Manifold Constrained Hyper Connection Kernel V3: @@ -21,7 +23,7 @@ namespace ck_tile { template struct MHCKernelV3 @@ -45,16 +47,23 @@ struct MHCKernelV3 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - // Calculate shared memory size based on BlockGemmShape - // The pipeline needs LDS for A[kM, kK] and B[kK, kN] + // Calculate LDS size for V1 pipeline + // V1 uses single-buffered LDS for A and B tiles constexpr index_t kM = Problem::BlockGemmShape::kM; constexpr index_t kN = Problem::BlockGemmShape::kN; constexpr index_t kK = Problem::BlockGemmShape::kK; - // Approximate LDS size (actual calculation is complex, but this is a safe upper bound) - constexpr index_t a_lds_size = kM * kK * sizeof(XDataType) * 2; - constexpr index_t b_lds_size = kN * kK * sizeof(PhiDataType) * 2; - return a_lds_size + b_lds_size; + constexpr index_t kLdsAlignmentInBytes = 16; + + // A LDS: [kM, kK] + constexpr index_t a_lds_size = kM * kK * sizeof(XDataType); + constexpr index_t a_lds_size_aligned = + ((a_lds_size + kLdsAlignmentInBytes - 1) / kLdsAlignmentInBytes) * kLdsAlignmentInBytes; + + // B LDS: [kN, kK] for column-major or [kK, kN] for row-major + constexpr index_t b_lds_size = kN * kK * sizeof(PhiDataType); + + return a_lds_size_aligned + b_lds_size; } // Grid configuration: 2D grid over (batch, output_dim) @@ -80,8 +89,9 @@ struct MHCKernelV3 { // 2D block indexing const index_t grid_n_size = (output_dim + kNTile - 1) / kNTile; - const index_t block_m = get_block_id() / grid_n_size; - const index_t block_n = get_block_id() % grid_n_size; + const index_t block_id = get_block_id(); + const index_t block_m = block_id / grid_n_size; + const index_t block_n = block_id % grid_n_size; const index_t batch_start = block_m * kMTile; const index_t out_start = block_n * kNTile; @@ -89,54 +99,51 @@ struct MHCKernelV3 if(batch_start >= batch || out_start >= output_dim) return; - // Create tensor views with adjusted pointers and dimensions - // The GEMM pipeline expects windows with origin {0,0} relative to the tensor view - const index_t remaining_batch = batch - batch_start; - const index_t remaining_output = output_dim - out_start; + // Create full tensor views (not adjusted) and use window origins to select regions + auto x_tensor_full = make_naive_tensor_view( + p_x, make_tuple(batch, nC), make_tuple(nC, 1), number<1>{}, number<1>{}); - auto x_tensor_unpadded = make_naive_tensor_view( - p_x + batch_start * nC, // Adjust pointer to start at this block's batch range - make_tuple(remaining_batch, nC), // Dimensions from this block's starting point - make_tuple(nC, 1), - number<1>{}, - number<1>{}); + // For column-major B [N, K], reinterpret row-major phi [nC, output_dim] + // as column-major [output_dim, nC] with strides [1, output_dim] + auto phi_tensor_full = make_naive_tensor_view( + p_phi, make_tuple(output_dim, nC), make_tuple(1, output_dim), number<1>{}, number<1>{}); - auto phi_tensor_unpadded = make_naive_tensor_view( - p_phi + out_start, // Adjust pointer to start at this block's output range - make_tuple(nC, remaining_output), // Dimensions from this block's starting point - make_tuple(remaining_output, 1), - number<1>{}, - number<1>{}); + // Pad tensors according to GEMM pipeline requirements + // For row-major A [M, K]: pad with sequence + auto x_tensor_padded = + pad_tensor_view(x_tensor_full, + make_tuple(number{}, number{}), + sequence{}); // Don't pad M, conditionally pad K - // Pad tensors to tile sizes to handle boundary conditions - auto x_tensor = pad_tensor_view( - x_tensor_unpadded, make_tuple(number{}, number{}), sequence<0, 1>{}); + // For column-major B [N, K]: pad with sequence + auto phi_tensor_padded = + pad_tensor_view(phi_tensor_full, + make_tuple(number{}, number{}), + sequence{}); // Don't pad N, conditionally pad K - auto phi_tensor = pad_tensor_view( - phi_tensor_unpadded, make_tuple(number{}, number{}), sequence<0, 1>{}); - - // Create DRAM tile windows with origin {0, 0} relative to the padded tensor views - // The pipeline will internally manage K-dimension iteration + // Create DRAM tile windows from padded tensors auto x_dram_window = - make_tile_window(x_tensor, + make_tile_window(x_tensor_padded, make_tuple(number{}, number{}), - {0, 0}); // Origin at {0, 0} relative to the padded tensor view + {batch_start, 0}); // Start at this block's batch range auto phi_dram_window = - make_tile_window(phi_tensor, - make_tuple(number{}, number{}), - {0, 0}); // Origin at {0, 0} relative to the padded tensor view + make_tile_window(phi_tensor_padded, + make_tuple(number{}, number{}), + {out_start, 0}); // Start at this block's output range - // Use GEMM pipeline v3 to compute the full GEMM - using GemmPipeline = GemmPipelineAgBgCrCompV3; + // Use GEMM pipeline v1 to compute the full GEMM (more robust for multi-block execution) + using GemmPipeline = GemmPipelineAGmemBGmemCRegV1; const index_t num_k_loops = (nC + kKTile - 1) / kKTile; - extern __shared__ char smem[]; + // Use static shared memory allocation (per-block, not shared across blocks!) + __shared__ char smem[GetSmemSize()]; auto gemm_pipeline = GemmPipeline{}; - // V3 pipeline expects non-tuple windows and uses identity functions internally - auto result_tile = gemm_pipeline(x_dram_window, phi_dram_window, num_k_loops, smem); + // V1 pipeline expects tuple-wrapped windows + auto result_tile = gemm_pipeline( + make_tuple(x_dram_window), make_tuple(phi_dram_window), num_k_loops, smem); // Apply elementwise operations (currently commented out for GEMM testing) constexpr auto result_spans = decltype(result_tile)::get_distributed_spans(); @@ -183,31 +190,28 @@ struct MHCKernelV3 // Cast result to output data type auto result_output = cast_tile(result_tile); - // Create output tensor view for efficient store_tile operation + // Create full output tensor view and use window origin constexpr index_t output_vector_size = 16 / sizeof(YDataType); - auto output_tensor_view_unpadded = make_naive_tensor_view( - p_output + batch_start * output_dim + - out_start, // Adjust pointer to this block's output region - make_tuple(remaining_batch, - remaining_output), // Dimensions from this block's starting point - make_tuple(output_dim, 1), // Strides: row-major layout - number{}, // Vector size for efficient memory access - number<1>{}); // Alignment + auto output_tensor_full = + make_naive_tensor_view(p_output, + make_tuple(batch, output_dim), + make_tuple(output_dim, 1), + number{}, + number<1>{}); - // Pad output tensor view to match the tile size (for boundary handling) - auto output_tensor_view = pad_tensor_view(output_tensor_view_unpadded, - make_tuple(number{}, number{}), - sequence<0, 1>{}); + // Pad output tensor view for boundary handling (row-major C: sequence) + auto output_tensor_padded = pad_tensor_view(output_tensor_full, + make_tuple(number{}, number{}), + sequence{}); - // Create tile window for the output using result_output's distribution - auto output_window = make_tile_window( - output_tensor_view, - make_tuple(number{}, number{}), - {0, 0}, // Origin at {0, 0} relative to the padded view - result_output.get_tile_distribution()); // Use distribution from result_output + // Create tile window with origin at this block's region + auto output_window = make_tile_window(output_tensor_padded, + make_tuple(number{}, number{}), + {batch_start, out_start}, + result_output.get_tile_distribution()); - // Store the result using the tile window (padding will prevent out-of-bounds writes) + // Store the result store_tile(output_window, result_output); } }; diff --git a/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp b/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp index 445f00ee6f..0257acdb69 100644 --- a/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp +++ b/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp @@ -35,7 +35,8 @@ struct MHCProblem // Layout types for BlockGemm using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [B, nC] - using BLayout = ck_tile::tensor_layout::gemm::RowMajor; // phi is row-major [nC, output_dim] + using BLayout = + ck_tile::tensor_layout::gemm::ColumnMajor; // phi treated as column-major for V1 pipeline using CLayout = ck_tile::tensor_layout::gemm::RowMajor; // output is row-major // For GEMM pipeline compatibility @@ -48,9 +49,9 @@ struct MHCProblem using BElementWise = identity; static constexpr bool TransposeC = false; - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; // TESTING: Disable N padding - static constexpr bool kPadK = false; + static constexpr bool kPadM = true; // Enable padding to help with boundary conditions + static constexpr bool kPadN = true; // Enable padding + static constexpr bool kPadK = true; // Enable padding static constexpr bool Preshuffle = false; static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave; @@ -64,7 +65,7 @@ struct MHCProblem static constexpr index_t kBlockSize = BlockShape::BlockSize; // Additional traits required by v3 pipeline - static constexpr bool DoubleSmemBuffer = false; + static constexpr bool DoubleSmemBuffer = true; // Enable double buffering for multi-block static constexpr bool UseStructuredSparsity = false; static constexpr bool FixedVectorSize = false;