// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

// Shared definitions for the ck_tile GEMM example:
//   * GemmConfig* structs      - tile / warp / pipeline tuning knobs,
//   * GemmTypeConfig           - A/B/accumulator/C data-type selection,
//   * PipelineTypeTraits       - pipeline id -> pipeline class templates,
//   * create_args()            - command-line option table,
//   * gemm()                   - declaration of the host-side entry point.
//
// NOTE(review): in this copy of the file the contents of angle brackets appear
// to have been stripped. `template <>` survives, but these do not compile as
// written:
//   * the bare `#include` lines,
//   * every `template struct ...` without a parameter list,
//   * the argument-less GemmTypeConfig / PipelineTypeTraits specializations,
//   * the `gemm` declaration.
// Restore them from upstream before building.

#pragma once

#include
#include

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"

// Default knobs inherited by every GemmConfig* below.
//
// A derived config changes a knob by redeclaring a static member of the same
// name. That redeclaration hides the base member: for example Pipeline,
// Scheduler, kBlockPerCu, NumWaveGroups, Preshuffle or TiledMMAPermuteN.
//
// Tile sizes (M/N/K_Tile, M/N/K_Warp, M/N/K_Warp_Tile) and DoubleSmemBuffer are
// NOT declared here. Every derived config must declare them itself.
struct GemmConfigBase
{
    // Per-dimension pad flags; all disabled by default.
    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool PermuteA              = false;
    static constexpr bool PermuteB              = false;
    static constexpr bool TransposeC            = false;
    static constexpr bool UseStructuredSparsity = false;

    // Raised to 2 by GemmConfigComputeV3_2, GemmConfigComputeV3_WMMA and
    // GemmConfigPreshufflePrefill.
    static constexpr int kBlockPerCu = 1;

    // NOTE(review): "Paritioner" is a misspelling of "Partitioner" in these
    // public member names. Renaming them requires updating every consumer of
    // the configs.
    static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
    static constexpr ck_tile::index_t TileParitionerM01      = 4;

    // Default scheduling: Intrawave scheduler on the COMPUTE_V3 pipeline.
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;

    // Overridden to 2 only by GemmConfigComputeV5.
    static constexpr ck_tile::index_t NumWaveGroups = 1;

    // Enabled only by the GemmConfigPreshuffle* configs.
    static constexpr bool Preshuffle       = false;
    static constexpr bool TiledMMAPermuteN = false;
};

// ---------------------------------------------------------------------------
// Tile configurations.
//
// Each struct below declares:
//   * M/N/K_Tile,
//   * the M/N/K_Warp counts,
//   * the M/N/K_Warp_Tile sizes.
//
// K_Tile is usually written as `<bytes> / sizeof(PrecType)`, so it spans a fixed
// number of bytes regardless of the element size.
//
// NOTE(review): the `template` lines below have lost their parameter lists.
// The bodies reference `PrecType`, so presumably `template <typename PrecType>`.
// `ck_tile::get_k_warp_tile()` may also have lost explicit template arguments.
// TODO confirm against upstream.
// ---------------------------------------------------------------------------

// MEMORY pipeline driven by the Interwave scheduler.
//   Tile:      128 x 32 x (128 bytes of PrecType)
//   Warps:     4 x 1 x 1
//   Warp tile: 32 x 32 x (8 for 2-byte PrecType, otherwise 16)
template struct GemmConfigMemoryInterwave : public GemmConfigBase
{
    // Memory-friendly configuration for the Interwave scheduler.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 32;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 4;
    static constexpr ck_tile::index_t N_Warp = 1;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;

    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
    // Hides the base's Intrawave default.
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};

// MEMORY pipeline with the base's (Intrawave) scheduler.
// Tile and warp shapes are identical to GemmConfigMemoryInterwave.
template struct GemmConfigMemoryIntrawave : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 32;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 4;
    static constexpr ck_tile::index_t N_Warp = 1;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;

    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
};

// COMPUTE_V3 with a small M tile.
//   Tile:      16 x 64 x (256 bytes of PrecType)
//   Warps:     1 x 4 x 1
//   Warp tile: 16 x 16
template struct GemmConfigComputeV3 : public GemmConfigBase
{
    // Compute V3 supports only the Intrawave scheduler (inherited from
    // GemmConfigBase).
    static constexpr ck_tile::index_t M_Tile = 16;
    static constexpr ck_tile::index_t N_Tile = 64;
    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile();

    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
};

// COMPUTE_V3, large tile.
//   Tile:      256 x 256 x (128 bytes of PrecType)
//   Warps:     2 x 2 x 1
//   Warp tile: 32 x 32
template struct GemmConfigComputeV3_1 : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile();

    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
};

// COMPUTE_V3, medium tile, with kBlockPerCu raised to 2.
//   Tile:      128 x 128 x (128 bytes of PrecType)
//   Warps:     2 x 2 x 1
//   Warp tile: 16 x 16
template struct GemmConfigComputeV3_2 : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile();

    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    static constexpr int kBlockPerCu = 2;
};

// COMPUTE_V3 variant with a fixed 16 x 16 x 16 warp tile.
// Unlike the other COMPUTE_V3 configs, K_Warp_Tile is hard-coded here rather
// than taken from get_k_warp_tile().
//   Tile:  128 x 128 x (64 bytes of PrecType)
//   Warps: 4 x 2 x 1
template struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 4;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;

    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    static constexpr int kBlockPerCu = 2;
};

// COMPUTE_V4 with DoubleSmemBuffer enabled.
//   Tile:      256 x 256 x (64 bytes of PrecType)
//   Warps:     2 x 2 x 1
//   Warp tile: 32 x 32
template struct GemmConfigComputeV4 : public GemmConfigBase
{
    // Compute V4 supports only the Intrawave scheduler (inherited from
    // GemmConfigBase).
    // It uses the ping-pong reader at the LDS level, presumably the reason
    // DoubleSmemBuffer is true below.
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile();

    static constexpr bool DoubleSmemBuffer = true;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
};

// Same as GemmConfigComputeV4 except that K_Tile spans 128 bytes instead of 64.
template struct GemmConfigComputeV4_1 : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile();

    static constexpr bool DoubleSmemBuffer = true;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
};

// COMPUTE_V5. This is the only config in this file with K_Warp > 1, and the
// only one that raises NumWaveGroups to 2.
//   Tile:      128 x 128 x (64 bytes of PrecType)
//   Warps:     1 x 1 x 2
//   Warp tile: 32 x 32
template struct GemmConfigComputeV5 : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 1;
    static constexpr ck_tile::index_t K_Warp = 2;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile();

    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
    static constexpr ck_tile::index_t NumWaveGroups = 2;
};

// COMPUTE_V6.
// K_Tile is a fixed 32 elements and K_Warp_Tile a fixed 16. Neither is scaled
// by sizeof(PrecType) nor taken from get_k_warp_tile().
//   Tile:      256 x 256 x 32
//   Warps:     2 x 2 x 1
//   Warp tile: 32 x 32 x 16
template struct GemmConfigComputeV6 : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 32;

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;

    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6;
    // Same value as the GemmConfigBase default.
    static constexpr ck_tile::index_t NumWaveGroups = 1;
};

// PRESHUFFLE_V2 config with a small M tile (Preshuffle = true, Default scheduler).
//   Tile:      16 x 64 x (256 bytes of PrecType)
//   Warps:     1 x 4 x 1
//   Warp tile: 16 x 16
// N_Repeat = 64 / 16 / 4 = 1, which is odd, so TiledMMAPermuteN evaluates to false.
template struct GemmConfigPreshuffleDecode : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 16;
    static constexpr ck_tile::index_t N_Tile = 64;
    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile();

    // Same value as the GemmConfigBase default.
    static constexpr int kBlockPerCu = 1;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
    static constexpr bool Preshuffle = true;
    static constexpr bool DoubleSmemBuffer = true;

    // Number of N_Warp_Tile-wide steps each N-warp covers across N_Tile.
    static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
    // Enabled only when N_Repeat is even.
    static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};

// PRESHUFFLE_V2 config with a larger tile and kBlockPerCu = 2.
//   Tile:      128 x 128 x (128 bytes of PrecType)
//   Warps:     1 x 4 x 1
//   Warp tile: 16 x 16
// N_Repeat = 128 / 16 / 4 = 2, which is even, so TiledMMAPermuteN evaluates to true.
template struct GemmConfigPreshufflePrefill : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile();

    static constexpr int kBlockPerCu = 2;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
    static constexpr bool Preshuffle = true;
    static constexpr bool DoubleSmemBuffer = true;

    // Number of N_Warp_Tile-wide steps each N-warp covers across N_Tile.
    static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
    // Enabled only when N_Repeat is even.
    static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};

// GemmConfigPreshufflePrefill with a fixed 16 x 16 x 16 warp tile. Only
// K_Warp_Tile actually differs from the base, which uses get_k_warp_tile().
//
// NOTE(review): N_Repeat and TiledMMAPermuteN are inherited. They were computed
// in the base's scope from the base's N_Warp_Tile, not from the redeclaration
// here. They stay consistent only because both values are 16. If N_Warp_Tile
// changes here, redeclare those two members as well.
//
// NOTE(review): the base-class template arguments are missing in this copy;
// presumably GemmConfigPreshufflePrefill<PrecType>. TODO confirm.
template struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill
{
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
};

// ---------------------------------------------------------------------------
// Data-type configurations.
//
// Each specialization names ADataType, BDataType, AccDataType and CDataType.
// The primary template is only declared, so an unsupported combination fails
// to compile.
//
// NOTE(review): the primary template's parameter list and every
// specialization's arguments are missing in this copy. The comments below
// identify each specialization by its A/B types only.
// ---------------------------------------------------------------------------
template struct GemmTypeConfig;

// half_t A/B, float accumulation, half_t C.
template <> struct GemmTypeConfig
{
    using ADataType   = ck_tile::half_t;
    using BDataType   = ck_tile::half_t;
    using AccDataType = float;
    using CDataType   = ck_tile::half_t;
    // ToDo: Add more bias config to support different categories of GEMM.
};

// bf16_t A/B, float accumulation, bf16_t C.
template <> struct GemmTypeConfig
{
    using ADataType   = ck_tile::bf16_t;
    using BDataType   = ck_tile::bf16_t;
    using AccDataType = float;
    using CDataType   = ck_tile::bf16_t;
};

// fp8_t A/B, float accumulation, half_t C.
template <> struct GemmTypeConfig
{
    using ADataType   = ck_tile::fp8_t;
    using BDataType   = ck_tile::fp8_t;
    using AccDataType = float;
    using CDataType   = ck_tile::half_t;
};

// bf8_t A/B, float accumulation, half_t C.
template <> struct GemmTypeConfig
{
    using ADataType   = ck_tile::bf8_t;
    using BDataType   = ck_tile::bf8_t;
    using AccDataType = float;
    using CDataType   = ck_tile::half_t;
};

// Mixed: fp8_t A, pk_int4_t B, float accumulation, half_t C.
template <> struct GemmTypeConfig
{
    using ADataType   = ck_tile::fp8_t;
    using BDataType   = ck_tile::pk_int4_t;
    using AccDataType = float;
    using CDataType   = ck_tile::half_t;
};

// Mixed: bf8_t A, pk_int4_t B, float accumulation, half_t C.
template <> struct GemmTypeConfig
{
    using ADataType   = ck_tile::bf8_t;
    using BDataType   = ck_tile::pk_int4_t;
    using AccDataType = float;
    using CDataType   = ck_tile::half_t;
};

// Mixed: half_t A, pk_int4_t B, float accumulation, half_t C.
template <> struct GemmTypeConfig
{
    using ADataType   = ck_tile::half_t;
    using BDataType   = ck_tile::pk_int4_t;
    using AccDataType = float;
    using CDataType   = ck_tile::half_t;
};

// int8_t A/B with int32_t accumulation. C is also int32_t, so the result is
// not narrowed back to 8 bits.
template <> struct GemmTypeConfig
{
    using ADataType   = ck_tile::int8_t;
    using BDataType   = ck_tile::int8_t;
    using AccDataType = int32_t;
    using CDataType   = int32_t;
};

// ---------------------------------------------------------------------------
// Pipeline traits.
//
// Each specialization exposes two alias templates:
//   GemmPipeline          - aliases a ck_tile::GemmPipeline* /
//                           WeightPreshufflePipeline* template,
//   UniversalGemmPipeline - aliases the matching Base* template.
// The primary template is only declared.
//
// NOTE(review): parameter lists and specialization arguments are missing in
// this copy. The key is presumably a ck_tile::GemmPipeline value: MEMORY,
// COMPUTE_V3..V6 or PRESHUFFLE_V2, i.e. the values the configs above set.
// Each alias presumably takes a pipeline-problem type. TODO confirm.
// ---------------------------------------------------------------------------
template struct PipelineTypeTraits;

// Presumably GemmPipeline::MEMORY.
template <> struct PipelineTypeTraits
{
    template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem;
    template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem;
};

// Presumably GemmPipeline::COMPUTE_V3.
template <> struct PipelineTypeTraits
{
    template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3;
    template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3;
};

// Presumably GemmPipeline::COMPUTE_V4.
template <> struct PipelineTypeTraits
{
    template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4;
    template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4;
};

// Presumably GemmPipeline::COMPUTE_V5.
template <> struct PipelineTypeTraits
{
    template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5;
    template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5;
};

// Presumably GemmPipeline::COMPUTE_V6.
template <> struct PipelineTypeTraits
{
    template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6;
    template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6;
};

// Presumably GemmPipeline::PRESHUFFLE_V2.
template <> struct PipelineTypeTraits
{
    template using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2;
    template using UniversalGemmPipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2;
};

// Builds the ck_tile::ArgParser that describes every command-line option of the
// GEMM example, each with its default value and help text.
//
// Only the option table is built and returned (by value). Parsing is not
// performed here.
inline auto create_args()
{
    ck_tile::ArgParser arg_parser;
    // Problem size.
    arg_parser.insert("m", "3840", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("k", "2048", "k dimension")
        // Layouts: "R" = row-major, "C" = column-major (per the help text).
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "C", "B tensor data layout - Column by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        // NOTE(review): the 0 default presumably means "derive the stride from
        // the layout"; confirm in the driver.
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        // NOTE(review): the help text does not mention int8, although an int8
        // GemmTypeConfig exists above. Confirm whether the driver accepts it.
        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
        // Benchmark control.
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("split_k", "1", "splitK value")
        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
        .insert("persistent", "0", "0:non-persistent, 1:persistent")
        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
        .insert("jsonfile", "gemm.json", "json file name to dump results")
        // NOTE(review): this default is the string "true", whereas the other
        // on/off options ("persistent", "json", "test_async") use "0"/"1". The
        // code reading this option must accept that spelling.
        .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
        .insert("rotating_count", "1000", "rotating count, defaults to 1000")
        .insert("test_async", "0", "0: normal gemm, 1: test async input scheduler");
    return arg_parser;
}

// host API
// Host-side GEMM entry point; declared here and defined elsewhere.
// NOTE(review): the template parameter list is missing in this copy. The float
// result is presumably the measured kernel time (compare the warmup / repeat /
// timer options). Confirm against the definition.
template float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);