// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include "gemm_mx_common.hpp" using ADataType = ck::bf6x16_pk_t; using BDataType = ck::bf6x16_pk_t; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; using CDataType = ck::half_t; using AccDataType = float; using CShuffleDataType = CDataType; using ALayout = Row; using BLayout = Col; using CLayout = Row; using AElementOp = PassThrough; // elementwise transformation for A matrix using BElementOp = PassThrough; // elementwise transformation for B matrix using CElementOp = PassThrough; // elementwise transformation for C matrix constexpr ck::index_t DataPackedSize = 16; // Packed representation of data constexpr ck::index_t ScaleBlockSize = 32; // scaling block size constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 bf6 = 16 bf6x16_pk_t constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, // ALayout BLayout, // BLayout CLayout, // CLayout ADataType, // ADataType XPackedDataType, // AScaleDataType BDataType, // BDataType XPackedDataType, // BScaleDataType CDataType, // CDataType AccDataType, // GemmAccDataType CShuffleDataType, // CShuffleDataType AElementOp, // AElementwiseOperation BElementOp, // BElementwiseOperation CElementOp, // CElementwiseOperation GemmSpec, // GemmSpec ScaleBlockSize, // ScaleBlockSize: Scaling block size 256, // BlockSize: Thread block size 128, // MPerBlock 128, // NPerBlock KPerBlock, // KPerBlock 1, // AK1 1, // BK1 16, // MPerXDL 16, // NPerXDL 4, // MXdlPerWave 4, // NXdlPerWave S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferDstScalarPerVector_AK1 true, // ABlockLdsExtraM S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferDstScalarPerVector_BK1 true, // BBlockLdsExtraN 2, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock BlkGemmPSched, // BlkGemmPipeSched BlkGemmPVer, // BlkGemmPipelineVer ADataType, // ComputeTypeA BDataType // ComputeTypeB >; int main(int argc, char* argv[]) { return run_mx_gemm_example(argc, argv) ? 0 : -1; }