// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include "mx_flatmm_instance.hpp" // clang-format off #define MXFLATMM_ARCH_TRAITS @MXFLATMM_ARCH_TRAITS@ #define A_DATA_TYPE @A_DATA_TYPE@ #define B_DATA_TYPE @B_DATA_TYPE@ #define C_DATA_TYPE @C_DATA_TYPE@ #define A_LAYOUT @A_LAYOUT@ #define B_LAYOUT @B_LAYOUT@ #define C_LAYOUT @C_LAYOUT@ #define PERSISTENT @PERSISTENT@ #define SPLIT_K @SPLIT_K@ #define HAS_HOT_LOOP @HAS_HOT_LOOP@ #define TAIL_NUMBER @TAIL_NUMBER@ // clang-format on using FP4 = ck_tile::pk_fp4_t; using FP8 = ck_tile::fp8_t; using FP6 = ck_tile::pk_fp6x16_t; using FP16 = ck_tile::fp16_t; using BF16 = ck_tile::bf16_t; using ROW = ck_tile::tensor_layout::gemm::RowMajor; using COL = ck_tile::tensor_layout::gemm::ColumnMajor; using ScaleType = ck_tile::e8m0_t; inline constexpr auto ODD = ck_tile::TailNumber::Odd; inline constexpr auto EVEN = ck_tile::TailNumber::Even; inline constexpr int ScaleGranularityM = 1; inline constexpr int ScaleGranularityN = 1; inline constexpr int ScaleGranularityK = 32; using ScaleM = ck_tile::FlatmmScalePointer; using ScaleN = ck_tile::FlatmmScalePointer; template float mx_flatmm_calc, /*AccDataType*/ float, C_DATA_TYPE, A_LAYOUT, B_LAYOUT, /*DsLayout*/ ck_tile::tuple<>, C_LAYOUT, ScaleM, ScaleN, PERSISTENT, /*CDEElementWise*/ ck_tile::element_wise::PassThrough, SPLIT_K, HAS_HOT_LOOP, TAIL_NUMBER>(const ck_tile::ScaleFlatmmHostArgs& args, const ck_tile::stream_config& s);