// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/utility/json_dump.hpp" struct GemmConfigMemory { // Memory friendly for Interwave scheduler static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 32; static constexpr ck_tile::index_t K_Tile = 64; static constexpr ck_tile::index_t M_Warp = 4; static constexpr ck_tile::index_t N_Warp = 1; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 { // Compute friendly for Intrawave scheduler static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 64; static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 { // Compute friendly for Intrawave scheduler // Using the ping pong reader in the lds level static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 32; static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma { // Compute friendly for Intrawave scheduler static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 64; static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; template struct PipelineTypeTraits; template <> struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; }; template <> struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; }; template <> struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; }; template struct BatchedGemmTypeConfig; template <> struct BatchedGemmTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using AccDataType = float; using CDataType = ck_tile::half_t; }; using Types = BatchedGemmTypeConfig; // Specific type aliases for easy access using ADataType = Types::ADataType; using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "512", "m dimension") .insert("n", "1024", "n dimension") .insert("k", "2048", "k dimension") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("batch_stride_a", "1048576", "Batch A stride") .insert("batch_stride_b", "2097152", "Batch B stride") .insert("batch_stride_c", "524288", "Batch C stride") .insert("batch_count", "8", "Batch count") .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "splitK value") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") .insert("jsonfile", "cktile_batched_gemm.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } // host API float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s);