// Copyright © Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" struct GemmConfigBase { static constexpr bool kPadM = true; static constexpr bool kPadN = true; static constexpr bool kPadK = true; static constexpr bool PermuteA = false; static constexpr bool PermuteB = false; static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; static constexpr bool Persistent = false; static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool DoubleSmemBuffer = false; }; template struct GemmConfigMemoryInterwave : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 32; static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; template struct StreamKGemmTypeConfig { using ADataType = ADataType_; using BDataType = BDataType_; using AccDataType = float; using CDataType = CDataType_; }; template struct DataTypeTraits; template <> struct DataTypeTraits { static constexpr const char* name = "fp32"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp16"; }; template <> struct DataTypeTraits { static constexpr const char* name = "bf16"; }; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "512", "m dimension") .insert("n", "512", "n dimension") .insert("k", "512", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Column by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("num_sk_blocks", "-1", "number of Stream-K blocks. -1: chosen by algorithm, or user selected") .insert("reduction_strategy", "atomic", "strategy for storing results in C tensor - atomic/reduction") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("prec", "fp16", "data type. fp16/bf16") .insert("warmup", "50", "number of iterations before benchmarking the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); }