MHC V3 with gemm pipeline

This commit is contained in:
Damien Lejeune
2026-02-05 17:11:09 +00:00
parent 43a5678fdf
commit 053aed9402
5 changed files with 83 additions and 78 deletions

View File

@@ -6,7 +6,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Manifold Constrained Hyper Connection Kernel V3:
@@ -21,7 +23,7 @@ namespace ck_tile {
template <typename Problem_,
typename Policy_ = MHCDefaultPolicy,
index_t kMTile_ = 64, // Batch tile size
index_t kNTile_ = 32, // Output dimension tile (can cover all 26 outputs)
index_t kNTile_ = 32, // Output dimension tile (can cover all 24 outputs)
index_t kKTile_ = 8, // K-tile for C dimension (must match BlockGemmShape::kK)
typename Activation_ = element_wise::Sigmoid>
struct MHCKernelV3
@@ -45,16 +47,23 @@ struct MHCKernelV3
// Returns the number of bytes of shared memory (LDS) the kernel reserves for the
// GEMM A and B tiles. The value is a compile-time constant derived from
// Problem::BlockGemmShape::{kM, kN, kK} and the element sizes of XDataType and
// PhiDataType; it is used as the extent of the `__shared__ char smem[...]` array
// in the kernel body further down in this listing.
//
// NOTE(review): as shown here the body declares a_lds_size and b_lds_size twice
// and has two return statements, so it looks like the removed and the added lines
// of a diff hunk rendered together without +/- markers. It is not compilable in
// this form -- confirm against the real header which half is live.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
// Calculate shared memory size based on BlockGemmShape
// The pipeline needs LDS for A[kM, kK] and B[kK, kN]
// Calculate LDS size for V1 pipeline
// V1 uses single-buffered LDS for A and B tiles
// NOTE(review): MHCProblem in this same listing sets DoubleSmemBuffer = true;
// confirm that the single-buffer sizing below is still sufficient for the
// pipeline that is actually instantiated.
constexpr index_t kM = Problem::BlockGemmShape::kM;
constexpr index_t kN = Problem::BlockGemmShape::kN;
constexpr index_t kK = Problem::BlockGemmShape::kK;
// Approximate LDS size (actual calculation is complex, but this is a safe upper bound)
// First variant: each tile's byte size is multiplied by 2 and the two are summed,
// with no alignment rounding.
constexpr index_t a_lds_size = kM * kK * sizeof(XDataType) * 2;
constexpr index_t b_lds_size = kN * kK * sizeof(PhiDataType) * 2;
return a_lds_size + b_lds_size;
// Second variant: exact tile sizes (no factor of 2); only the A region is rounded
// up to a multiple of kLdsAlignmentInBytes.
constexpr index_t kLdsAlignmentInBytes = 16;
// A LDS: [kM, kK]
constexpr index_t a_lds_size = kM * kK * sizeof(XDataType);
// Round the A size up to the next multiple of 16 bytes -- presumably so that the
// B region placed after it starts 16-byte aligned; verify against the pipeline's
// own LDS layout.
constexpr index_t a_lds_size_aligned =
((a_lds_size + kLdsAlignmentInBytes - 1) / kLdsAlignmentInBytes) * kLdsAlignmentInBytes;
// B LDS: [kN, kK] for column-major or [kK, kN] for row-major
// The B size is added as-is; it is not rounded to kLdsAlignmentInBytes.
constexpr index_t b_lds_size = kN * kK * sizeof(PhiDataType);
return a_lds_size_aligned + b_lds_size;
}
// Grid configuration: 2D grid over (batch, output_dim)
@@ -80,8 +89,9 @@ struct MHCKernelV3
{
// 2D block indexing
const index_t grid_n_size = (output_dim + kNTile - 1) / kNTile;
const index_t block_m = get_block_id() / grid_n_size;
const index_t block_n = get_block_id() % grid_n_size;
const index_t block_id = get_block_id();
const index_t block_m = block_id / grid_n_size;
const index_t block_n = block_id % grid_n_size;
const index_t batch_start = block_m * kMTile;
const index_t out_start = block_n * kNTile;
@@ -89,54 +99,51 @@ struct MHCKernelV3
if(batch_start >= batch || out_start >= output_dim)
return;
// Create tensor views with adjusted pointers and dimensions
// The GEMM pipeline expects windows with origin {0,0} relative to the tensor view
const index_t remaining_batch = batch - batch_start;
const index_t remaining_output = output_dim - out_start;
// Create full tensor views (not adjusted) and use window origins to select regions
auto x_tensor_full = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(batch, nC), make_tuple(nC, 1), number<1>{}, number<1>{});
auto x_tensor_unpadded = make_naive_tensor_view<address_space_enum::global>(
p_x + batch_start * nC, // Adjust pointer to start at this block's batch range
make_tuple(remaining_batch, nC), // Dimensions from this block's starting point
make_tuple(nC, 1),
number<1>{},
number<1>{});
// For column-major B [N, K], reinterpret row-major phi [nC, output_dim]
// as column-major [output_dim, nC] with strides [1, output_dim]
auto phi_tensor_full = make_naive_tensor_view<address_space_enum::global>(
p_phi, make_tuple(output_dim, nC), make_tuple(1, output_dim), number<1>{}, number<1>{});
auto phi_tensor_unpadded = make_naive_tensor_view<address_space_enum::global>(
p_phi + out_start, // Adjust pointer to start at this block's output range
make_tuple(nC, remaining_output), // Dimensions from this block's starting point
make_tuple(remaining_output, 1),
number<1>{},
number<1>{});
// Pad tensors according to GEMM pipeline requirements
// For row-major A [M, K]: pad with sequence<false, kPadK>
auto x_tensor_padded =
pad_tensor_view(x_tensor_full,
make_tuple(number<kMTile>{}, number<kKTile>{}),
sequence<false, Problem::kPadK>{}); // Don't pad M, conditionally pad K
// Pad tensors to tile sizes to handle boundary conditions
auto x_tensor = pad_tensor_view(
x_tensor_unpadded, make_tuple(number<kMTile>{}, number<kKTile>{}), sequence<0, 1>{});
// For column-major B [N, K]: pad with sequence<false, kPadK>
auto phi_tensor_padded =
pad_tensor_view(phi_tensor_full,
make_tuple(number<kNTile>{}, number<kKTile>{}),
sequence<false, Problem::kPadK>{}); // Don't pad N, conditionally pad K
auto phi_tensor = pad_tensor_view(
phi_tensor_unpadded, make_tuple(number<kKTile>{}, number<kNTile>{}), sequence<0, 1>{});
// Create DRAM tile windows with origin {0, 0} relative to the padded tensor views
// The pipeline will internally manage K-dimension iteration
// Create DRAM tile windows from padded tensors
auto x_dram_window =
make_tile_window(x_tensor,
make_tile_window(x_tensor_padded,
make_tuple(number<kMTile>{}, number<kKTile>{}),
{0, 0}); // Origin at {0, 0} relative to the padded tensor view
{batch_start, 0}); // Start at this block's batch range
auto phi_dram_window =
make_tile_window(phi_tensor,
make_tuple(number<kKTile>{}, number<kNTile>{}),
{0, 0}); // Origin at {0, 0} relative to the padded tensor view
make_tile_window(phi_tensor_padded,
make_tuple(number<kNTile>{}, number<kKTile>{}),
{out_start, 0}); // Start at this block's output range
// Use GEMM pipeline v3 to compute the full GEMM
using GemmPipeline = GemmPipelineAgBgCrCompV3<Problem>;
// Use GEMM pipeline v1 to compute the full GEMM (more robust for multi-block execution)
using GemmPipeline = GemmPipelineAGmemBGmemCRegV1<Problem>;
const index_t num_k_loops = (nC + kKTile - 1) / kKTile;
extern __shared__ char smem[];
// Use static shared memory allocation (per-block, not shared across blocks!)
__shared__ char smem[GetSmemSize()];
auto gemm_pipeline = GemmPipeline{};
// V3 pipeline expects non-tuple windows and uses identity functions internally
auto result_tile = gemm_pipeline(x_dram_window, phi_dram_window, num_k_loops, smem);
// V1 pipeline expects tuple-wrapped windows
auto result_tile = gemm_pipeline(
make_tuple(x_dram_window), make_tuple(phi_dram_window), num_k_loops, smem);
// Apply elementwise operations (currently commented out for GEMM testing)
constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
@@ -183,31 +190,28 @@ struct MHCKernelV3
// Cast result to output data type
auto result_output = cast_tile<YDataType>(result_tile);
// Create output tensor view for efficient store_tile operation
// Create full output tensor view and use window origin
constexpr index_t output_vector_size = 16 / sizeof(YDataType);
auto output_tensor_view_unpadded = make_naive_tensor_view<address_space_enum::global>(
p_output + batch_start * output_dim +
out_start, // Adjust pointer to this block's output region
make_tuple(remaining_batch,
remaining_output), // Dimensions from this block's starting point
make_tuple(output_dim, 1), // Strides: row-major layout
number<output_vector_size>{}, // Vector size for efficient memory access
number<1>{}); // Alignment
auto output_tensor_full =
make_naive_tensor_view<address_space_enum::global>(p_output,
make_tuple(batch, output_dim),
make_tuple(output_dim, 1),
number<output_vector_size>{},
number<1>{});
// Pad output tensor view to match the tile size (for boundary handling)
auto output_tensor_view = pad_tensor_view(output_tensor_view_unpadded,
make_tuple(number<kMTile>{}, number<kNTile>{}),
sequence<0, 1>{});
// Pad output tensor view for boundary handling (row-major C: sequence<false, kPadN>)
auto output_tensor_padded = pad_tensor_view(output_tensor_full,
make_tuple(number<kMTile>{}, number<kNTile>{}),
sequence<false, Problem::kPadN>{});
// Create tile window for the output using result_output's distribution
auto output_window = make_tile_window(
output_tensor_view,
make_tuple(number<kMTile>{}, number<kNTile>{}),
{0, 0}, // Origin at {0, 0} relative to the padded view
result_output.get_tile_distribution()); // Use distribution from result_output
// Create tile window with origin at this block's region
auto output_window = make_tile_window(output_tensor_padded,
make_tuple(number<kMTile>{}, number<kNTile>{}),
{batch_start, out_start},
result_output.get_tile_distribution());
// Store the result using the tile window (padding will prevent out-of-bounds writes)
// Store the result
store_tile(output_window, result_output);
}
};

View File

@@ -35,7 +35,8 @@ struct MHCProblem
// Layout types for BlockGemm
using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [B, nC]
using BLayout = ck_tile::tensor_layout::gemm::RowMajor; // phi is row-major [nC, output_dim]
using BLayout =
ck_tile::tensor_layout::gemm::ColumnMajor; // phi treated as column-major for V1 pipeline
using CLayout = ck_tile::tensor_layout::gemm::RowMajor; // output is row-major
// For GEMM pipeline compatibility
@@ -48,9 +49,9 @@ struct MHCProblem
using BElementWise = identity;
static constexpr bool TransposeC = false;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false; // TESTING: Disable N padding
static constexpr bool kPadK = false;
static constexpr bool kPadM = true; // Enable padding to help with boundary conditions
static constexpr bool kPadN = true; // Enable padding
static constexpr bool kPadK = true; // Enable padding
static constexpr bool Preshuffle = false;
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
@@ -64,7 +65,7 @@ struct MHCProblem
static constexpr index_t kBlockSize = BlockShape::BlockSize;
// Additional traits required by v3 pipeline
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool DoubleSmemBuffer = true; // Enable double buffering for multi-block
static constexpr bool UseStructuredSparsity = false;
static constexpr bool FixedVectorSize = false;