Files
composable_kernel/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp
Damien Lejeune 43a5678fdf WIP: MHC v3
2026-02-05 13:04:18 +00:00

80 lines
3.0 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
namespace ck_tile {
/// Compile-time problem descriptor for the MHC op.
///
/// Collects, as nested aliases and static constants, the element types, the
/// block/warp GEMM tiling, the operand layouts and the pipeline switches.
/// The struct has no non-static data members.
///
/// @tparam XDataType_       element type of x; also reused for the phi weights
/// @tparam ComputeDataType_ element type exposed as the GEMM C (accumulator) type
/// @tparam YDataType_       element type exposed as YDataType
/// @tparam BlockShape_      block shape type; must provide a `BlockSize` constant
template <typename XDataType_, typename ComputeDataType_, typename YDataType_, typename BlockShape_>
struct MHCProblem
{
    // ---- Template arguments with cv/ref qualifiers stripped ----------------
    using BlockShape      = remove_cvref_t<BlockShape_>;
    using XDataType       = remove_cvref_t<XDataType_>;
    using YDataType       = remove_cvref_t<YDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;

    // The phi weight matrix shares its element type with x.
    using PhiDataType = XDataType;

    // ---- BlockGemm-facing names --------------------------------------------
    // BlockGemm looks these members up by name, so the MHC types are re-exported
    // under the A/B/C convention: A = x (input), B = phi (weights),
    // C = output/accumulator.
    using ADataType = XDataType;
    using BDataType = PhiDataType;
    using CDataType = ComputeDataType;

    // Block size, re-exported under the name BlockGemm expects.
    static constexpr index_t kBlockSize = BlockShape::BlockSize;

    // Tiling handed to BlockGemm (provides the kM/kN/kK members it needs).
    // The warp tile is the 32x32x8 float32 warp-gemm configuration; two warps
    // along M and one along N give a 64x32 block tile.
    using BlockGemmShape = TileGemmShape<
        sequence<64, 32, 8>,   // block tile  (M, N, K)
        sequence<2, 1, 1>,     // block warps (M, N, K)
        sequence<32, 32, 8>>;  // warp tile   (M, N, K)

    // ---- Operand layouts ---------------------------------------------------
    // All three operands are row-major: x is [B, nC], phi is [nC, output_dim].
    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
    using BLayout = ck_tile::tensor_layout::gemm::RowMajor;
    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;

    // ---- Members required by the GEMM pipeline -----------------------------
    // Single-element tuples wrapping the A/B types and layouts.
    using AsDataTypeTuple = tuple<ADataType>;
    using AsLayoutTuple   = tuple<ALayout>;
    using BsDataTypeTuple = tuple<BDataType>;
    using BsLayoutTuple   = tuple<BLayout>;

    // No element-wise transform is applied to either input operand.
    using AElementWise = identity;
    using BElementWise = identity;

    // Padding is switched off in every dimension. The N setting was tagged
    // "TESTING: Disable N padding" by the author, i.e. it may be provisional.
    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool TransposeC = false;
    static constexpr bool Preshuffle = false;

    static constexpr auto Scheduler        = GemmPipelineScheduler::Intrawave;
    static constexpr index_t NumWaveGroups = 1;

    // NOTE(review): VectorLoadSize is presumably in bytes (16 / sizeof(float) == 4
    // would match VectorSizeA/B for float32) - confirm against the pipeline.
    static constexpr index_t VectorLoadSize = 16;
    static constexpr index_t VectorSizeA    = 4;
    static constexpr index_t VectorSizeB    = 4;

    // ---- Extra switches required by the v3 pipeline (all disabled) ---------
    static constexpr bool DoubleSmemBuffer      = false;
    static constexpr bool UseStructuredSparsity = false;
    static constexpr bool FixedVectorSize       = false;

    // Nested traits bundle; the persistent-kernel mode is off.
    struct Traits
    {
        static constexpr bool UsePersistentKernel = false;
    };

    /// @return the fixed identifier "MHCProblem" (host-side only).
    CK_TILE_HOST static const std::string GetName()
    {
        return std::string("MHCProblem");
    }
};
} // namespace ck_tile