// SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { enum struct BlockGemmPipelineVersion { // For GEMM v1, // Naive v2, // Mem v3, // Comp v4, // Comp, double lds buffer v5, // Comp, double global prefetch register buffer // For GEMM with preshuffled weight // v1, single lds buffer // v2, double lds buffer }; enum struct BlockGemmPipelineScheduler { Intrawave, Interwave, }; enum struct TailNumber { // Single / Double buffer pipeline Odd, Even, // Long prefetch pipeline, up to 8 One, Two, Three, Four, Five, Six, Seven, // Unroll stages > Prefetch stages, number of loop is multiple of unroll stages Empty, // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add // prefetchstages Full, }; enum SchedulerGroup : uint32_t { SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions SCHED_GROUP_VMEM = 0x020, // Global memory operations SCHED_GROUP_LDS_READ = 0x100, // LDS read operations SCHED_GROUP_LDS_WRITE = 0x200 // LDS write operations }; template struct BlockwiseGemmXdlops_pipeline_hotloop_inst { static constexpr index_t WaveSize = 64; static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t A_LDS_Read_Width = ALDSReadWidth; static constexpr index_t B_LDS_Read_Width = BLDSReadWidth; static constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); static constexpr index_t B_Buffer_Load_Inst_Num = NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); static constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); static constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); static constexpr index_t A_LDS_Read_Inst_Num = WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); static constexpr index_t B_LDS_Read_Inst_Num = WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); static constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); static constexpr index_t C_MFMA_SpeedUp = IsF4F6 ? 2 : 1; static constexpr index_t C_MFMA_Inst_Cycle = []() { if constexpr(NPerXDL == 16) { return KPerXDL == 128 ? 32 / C_MFMA_SpeedUp : 16 / C_MFMA_SpeedUp; } else if constexpr(NPerXDL == 32) { return KPerXDL == 64 ? 64 / C_MFMA_SpeedUp : 32 / C_MFMA_SpeedUp; } }(); static constexpr auto Print() { printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n", BlockSize, WaveSize, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, KPerXDL); printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " "%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n" "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " "%d/ %d\n", A_Buffer_Load_Inst_Num, B_Buffer_Load_Inst_Num, A_LDS_Write_Inst_Num, B_LDS_Write_Inst_Num, A_LDS_Read_Inst_Num, B_LDS_Read_Inst_Num, C_MFMA_Inst_Num, C_MFMA_Inst_Cycle, A_LDS_Read_Width, B_LDS_Read_Width, ALDSWriteWidth, BLDSWriteWidth, ABufferLoadWidth, BBufferLoadWidth); } }; } // namespace ck