// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/flatmm.hpp" #include "ck_tile/ops/gemm.hpp" // GEMM config with 16x16 warp tile struct MXfp4_FlatmmConfig16 { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 512; static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 128; static constexpr bool kPadM = false; static constexpr bool kPadN = false; static constexpr bool kPadK = false; static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; static constexpr int kBlockPerCu = 1; static constexpr int TileParitionerGroupNum = 8; static constexpr int TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = false; static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = false; }; struct MXfp8_FlatmmConfig16 { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 128; static constexpr bool kPadM = false; static constexpr bool kPadN = false; static constexpr bool kPadK = false; static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; static constexpr int kBlockPerCu = 1; static constexpr int TileParitionerGroupNum = 8; static constexpr int TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = false; static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = false; }; struct MXf8f4_FlatmmConfig16 { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 128; static constexpr bool kPadM = false; static constexpr bool kPadN = false; static constexpr bool kPadK = false; static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; static constexpr int kBlockPerCu = 1; static constexpr int TileParitionerGroupNum = 8; static constexpr int TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = false; static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = false; }; struct MXf4f8_FlatmmConfig16 { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 128; static constexpr bool kPadM = false; static constexpr bool kPadN = false; static constexpr bool kPadK = false; static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; static constexpr int kBlockPerCu = 1; static constexpr int TileParitionerGroupNum = 8; static constexpr int TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = false; static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = false; }; template float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, const ck_tile::stream_config& s);