// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" template struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0> { using Base = ck_tile::UniversalGemmHostArgs<1, 1, 0>; MXGemmHostArgs(const void* a_ptr, const void* b_ptr, void* c_ptr_, ck_tile::index_t k_batch_, ck_tile::index_t M_, ck_tile::index_t N_, ck_tile::index_t K_, ck_tile::index_t stride_A_, ck_tile::index_t stride_B_, ck_tile::index_t stride_C_, ScaleM scale_m_, ScaleN scale_n_) : Base({a_ptr}, {b_ptr}, {}, c_ptr_, k_batch_, M_, N_, K_, {stride_A_}, {stride_B_}, {}, stride_C_), scale_m(scale_m_), scale_n(scale_n_) { } ScaleM scale_m; ScaleN scale_n; }; // GEMM config with 16x16 warp tile struct MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 128; static constexpr bool kPadM = false; static constexpr bool kPadN = false; static constexpr bool kPadK = false; static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; static constexpr int kBlockPerCu = 1; static constexpr int TileParitionerGroupNum = 8; static constexpr int TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = true; // comp_async uses double buffer static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = false; }; struct MXfp4_GemmConfig16 : MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 256; }; // GEMM config with 16x16 warp tile struct MXfp8_GemmConfig16 : MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 256; };