mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
WIP: arbitrary batch dim
This commit is contained in:
@@ -9,40 +9,37 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename BlockShape_>
|
||||
template <typename XDataType_, typename ComputeDataType_, typename YDataType_, typename BlockShape_>
|
||||
struct MHCProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
|
||||
// PhiDataType is the same as XDataType for the weight matrix
|
||||
using PhiDataType = XDataType;
|
||||
|
||||
|
||||
// BlockGemm compatibility - map our types to BlockGemm's expected types
|
||||
using ADataType = XDataType; // Input matrix A
|
||||
using BDataType = PhiDataType; // Weight matrix B (phi)
|
||||
using CDataType = ComputeDataType; // Output/accumulator matrix C
|
||||
|
||||
using ADataType = XDataType; // Input matrix A
|
||||
using BDataType = PhiDataType; // Weight matrix B (phi)
|
||||
using CDataType = ComputeDataType; // Output/accumulator matrix C
|
||||
|
||||
// BlockGemmShape with kM, kN, kK members for BlockGemm
|
||||
// BlockGemm expects windows to match exactly: A[kM, kK], B[kK, kN]
|
||||
// Our windows: x[16, 256], phi[256, 16]
|
||||
// Try matching to warp gemm size: kM=16, kN=16, kK=16
|
||||
// We'll need to iterate over K dimension
|
||||
using BlockGemmShape = MHCGemmShape<16, 16, 16>;
|
||||
|
||||
|
||||
// Keep original BlockShape for other uses
|
||||
// using BlockShape is already defined above
|
||||
|
||||
|
||||
// Layout types for BlockGemm
|
||||
using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [1, nC]
|
||||
using BLayout = ck_tile::tensor_layout::gemm::RowMajor; // phi is row-major [nC, n]
|
||||
using CLayout = ck_tile::tensor_layout::gemm::RowMajor; // output is row-major
|
||||
|
||||
using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [1, nC]
|
||||
using BLayout = ck_tile::tensor_layout::gemm::RowMajor; // phi is row-major [nC, n]
|
||||
using CLayout = ck_tile::tensor_layout::gemm::RowMajor; // output is row-major
|
||||
|
||||
// kBlockSize for BlockGemm compatibility
|
||||
static constexpr index_t kBlockSize = BlockShape::BlockSize;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user