WIP: MHC v3

This commit is contained in:
Damien Lejeune
2026-02-05 13:04:18 +00:00
parent 6ea40157f1
commit 43a5678fdf
13 changed files with 957 additions and 41 deletions

View File

@@ -26,22 +26,54 @@ struct MHCProblem
using CDataType = ComputeDataType; // Output/accumulator matrix C
// BlockGemmShape with kM, kN, kK members for BlockGemm
// BlockGemm expects windows to match exactly: A[kM, kK], B[kK, kN]
// Our windows: x[16, 256], phi[256, 16]
// Try matching to warp gemm size: kM=16, kN=16, kK=16
// We'll need to iterate over K dimension
using BlockGemmShape = MHCGemmShape<16, 16, 16>;
// Keep original BlockShape for other uses
// using BlockShape is already defined above
// Use supported warp gemm configuration for float32: 32x32x8
// We'll use 2 warps in M and 1 warp in N to get 64x32 block
using BlockGemmShape =
TileGemmShape<sequence<64, 32, 8>, // BlockTile (M, N, K)
sequence<2, 1, 1>, // BlockWarps (2 warps in M, 1 in N, 1 in K)
sequence<32, 32, 8>>; // WarpTile (matches available float32 MFMA)
// Layout types for BlockGemm
using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [1, nC]
using BLayout = ck_tile::tensor_layout::gemm::RowMajor; // phi is row-major [nC, n]
using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [B, nC]
using BLayout = ck_tile::tensor_layout::gemm::RowMajor; // phi is row-major [nC, output_dim]
using CLayout = ck_tile::tensor_layout::gemm::RowMajor; // output is row-major
// Aliases provided for GEMM pipeline compatibility: the A/B data types and
// layouts are each wrapped in a one-element tuple.
using AsDataTypeTuple = tuple<ADataType>;
using BsDataTypeTuple = tuple<BDataType>;
using AsLayoutTuple = tuple<ALayout>;
using BsLayoutTuple = tuple<BLayout>;
// Both element-wise aliases are `identity`.
// NOTE(review): presumably this means no element-wise transform is applied to
// A or B -- confirm against the pipeline that consumes these aliases.
using AElementWise = identity;
using BElementWise = identity;
static constexpr bool TransposeC = false;
// Padding is disabled along M, N and K.
// NOTE(review): presumably this requires the problem dimensions to be exact
// multiples of the block tile -- confirm before using with arbitrary shapes.
static constexpr bool kPadM = false;
static constexpr bool kPadN = false; // TESTING: Disable N padding
static constexpr bool kPadK = false;
static constexpr bool Preshuffle = false;
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
static constexpr index_t NumWaveGroups = 1;
// NOTE(review): the unit of VectorLoadSize is not visible here; presumably
// bytes (16 bytes would equal 4 x 32-bit elements, matching VectorSizeA/B = 4)
// -- TODO confirm.
static constexpr index_t VectorLoadSize = 16;
static constexpr index_t VectorSizeA = 4;
static constexpr index_t VectorSizeB = 4;
// kBlockSize for BlockGemm compatibility; mirrors BlockShape::BlockSize.
static constexpr index_t kBlockSize = BlockShape::BlockSize;
// Additional traits required by v3 pipeline; all three are switched off.
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool FixedVectorSize = false;
/// Nested compile-time trait bundle for this problem type.
/// NOTE(review): presumably read by the GEMM pipeline/kernel -- confirm.
struct Traits
{
    /// Compile-time flag, fixed to false.
    static constexpr bool UsePersistentKernel{false};
};
/// Returns the fixed identifier string "MHCProblem" for this problem type.
/// The result is returned as a non-const value: a const-qualified by-value
/// return would prevent callers from moving out of the temporary.
CK_TILE_HOST static std::string GetName() { return "MHCProblem"; }
};
} // namespace ck_tile