mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
WIP: MHC v3
This commit is contained in:
@@ -26,22 +26,54 @@ struct MHCProblem
|
||||
using CDataType = ComputeDataType; // Output/accumulator matrix C
|
||||
|
||||
// BlockGemmShape with kM, kN, kK members for BlockGemm
|
||||
// BlockGemm expects windows to match exactly: A[kM, kK], B[kK, kN]
|
||||
// Our windows: x[16, 256], phi[256, 16]
|
||||
// Try matching to warp gemm size: kM=16, kN=16, kK=16
|
||||
// We'll need to iterate over K dimension
|
||||
using BlockGemmShape = MHCGemmShape<16, 16, 16>;
|
||||
|
||||
// Keep original BlockShape for other uses
|
||||
// using BlockShape is already defined above
|
||||
// Use supported warp gemm configuration for float32: 32x32x8
|
||||
// Use 2 warps along M and 1 warp along N, each covering a 32x32 warp tile, to get a 64x32 block tile
|
||||
using BlockGemmShape =
|
||||
TileGemmShape<sequence<64, 32, 8>, // BlockTile (M, N, K)
|
||||
sequence<2, 1, 1>, // BlockWarps (2 warps in M, 1 in N, 1 in K)
|
||||
sequence<32, 32, 8>>; // WarpTile (matches available float32 MFMA)
|
||||
|
||||
// Layout types for BlockGemm
|
||||
using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [1, nC]
|
||||
using BLayout = ck_tile::tensor_layout::gemm::RowMajor; // phi is row-major [nC, n]
|
||||
using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [B, nC]
|
||||
using BLayout = ck_tile::tensor_layout::gemm::RowMajor; // phi is row-major [nC, output_dim]
|
||||
using CLayout = ck_tile::tensor_layout::gemm::RowMajor; // output is row-major
|
||||
|
||||
// For GEMM pipeline compatibility
|
||||
using AsDataTypeTuple = tuple<ADataType>;
|
||||
using BsDataTypeTuple = tuple<BDataType>;
|
||||
using AsLayoutTuple = tuple<ALayout>;
|
||||
using BsLayoutTuple = tuple<BLayout>;
|
||||
|
||||
using AElementWise = identity;
|
||||
using BElementWise = identity;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false; // TESTING: Disable N padding
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool Preshuffle = false;
|
||||
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
|
||||
static constexpr index_t VectorLoadSize = 16;
|
||||
static constexpr index_t VectorSizeA = 4;
|
||||
static constexpr index_t VectorSizeB = 4;
|
||||
|
||||
// kBlockSize for BlockGemm compatibility
|
||||
static constexpr index_t kBlockSize = BlockShape::BlockSize;
|
||||
|
||||
// Additional traits required by v3 pipeline
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool FixedVectorSize = false;
|
||||
|
||||
// Nested trait bundle. Its only member, UsePersistentKernel, is a compile-time flag fixed to false.
// NOTE(review): presumably queried by the kernel/pipeline layer that consumes this problem type — confirm.
struct Traits
|
||||
{
|
||||
static constexpr bool UsePersistentKernel = false;
|
||||
};
|
||||
|
||||
// Name of this problem type: always the literal "MHCProblem".
CK_TILE_HOST static const std::string GetName()
{
    return std::string{"MHCProblem"};
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user