mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
150 lines
4.5 KiB
C++
150 lines
4.5 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck/utility/common_header.hpp"
|
|
#include "ck/tensor_description/tensor_adaptor.hpp"
|
|
|
|
namespace ck {
|
|
|
|
/// Selects which block-level GEMM pipeline implementation is used.
///
/// The same enumerators are reused by the GEMM pipelines with preshuffled weight, where they
/// carry a different meaning (listed at the end of the enumeration).
enum struct BlockGemmPipelineVersion : int
{
    // ---- Regular GEMM ----
    v1, ///< Naive
    v2, ///< Mem
    v3, ///< Comp
    v4, ///< Comp, double LDS buffer
    v5, ///< Comp, double global prefetch register buffer

    // ---- GEMM with preshuffled weight ----
    //   v1: single LDS buffer
    //   v2: double LDS buffer
};
|
|
/// Instruction-scheduling strategy a block GEMM pipeline is specialized for.
/// NOTE(review): presumably interleaving within a single wave vs. across waves — confirm
/// against the pipeline implementations.
enum struct BlockGemmPipelineScheduler : int
{
    Intrawave, ///< Intrawave scheduling
    Interwave, ///< Interwave scheduling
};
|
|
|
|
/// Tail variants a block GEMM pipeline can be specialized for.
/// NOTE(review): presumably describes the leftover loop iterations to drain after the hot
/// loop — confirm against the pipeline implementations.
enum struct TailNumber : int
{
    // -- Single / double buffer pipelines --
    Odd,
    Even,

    // -- Long prefetch pipelines (up to 8) --
    One,
    Two,
    Three,
    Four,
    Five,
    Six,
    Seven,

    // Unroll stages > prefetch stages: the loop count is a multiple of the unroll stages.
    Empty,
    // Unroll stages <= prefetch stages: the loop count is a multiple of the unroll stages
    // plus the prefetch stages.
    Full,
};
|
|
|
|
/// Bit masks naming classes of instructions for scheduling-group hints.
///
/// Kept as an unscoped enum so the enumerators convert implicitly to integers.
/// NOTE(review): the values appear to match the instruction-class mask bits of the AMDGPU
/// `sched_group_barrier` builtin (LLVM AMDGPUUsage) — confirm against the call sites.
enum SchedulerGroup : uint32_t
{
    SCHED_GROUP_MFMA      = 0x0008, ///< Matrix FMA instructions
    SCHED_GROUP_VMEM      = 0x0020, ///< Global memory operations
    SCHED_GROUP_LDS_READ  = 0x0100, ///< LDS read operations
    SCHED_GROUP_LDS_WRITE = 0x0200, ///< LDS write operations
};
|
|
|
|
template <index_t BlockSize,
|
|
index_t MPerBlock,
|
|
index_t NPerBlock,
|
|
index_t KPerBlock,
|
|
index_t ABufferLoadWidth,
|
|
index_t BBufferLoadWidth,
|
|
index_t ALDSWriteWidth,
|
|
index_t BLDSWriteWidth,
|
|
index_t ALDSReadWidth,
|
|
index_t BLDSReadWidth,
|
|
index_t MRepeat,
|
|
index_t NRepeat,
|
|
index_t MPerXDL,
|
|
index_t NPerXDL,
|
|
index_t KPerXDL,
|
|
bool IsF4F6 = false>
|
|
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
|
|
{
|
|
static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
|
|
static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
|
|
static constexpr index_t WaveSize = BlockSize / WaveNumM / WaveNumN;
|
|
|
|
static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
|
|
static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
|
|
|
|
static constexpr index_t A_Buffer_Load_Inst_Num =
|
|
MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
|
|
static constexpr index_t B_Buffer_Load_Inst_Num =
|
|
NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
|
|
|
|
static constexpr index_t A_LDS_Write_Inst_Num =
|
|
MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
|
|
static constexpr index_t B_LDS_Write_Inst_Num =
|
|
NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
|
|
|
|
static constexpr index_t A_LDS_Read_Inst_Num =
|
|
WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
|
|
static constexpr index_t B_LDS_Read_Inst_Num =
|
|
WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
|
|
|
|
static constexpr index_t C_MFMA_Inst_Num =
|
|
MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
|
|
|
static constexpr index_t C_MFMA_SpeedUp = IsF4F6 ? 2 : 1;
|
|
|
|
static constexpr index_t C_MFMA_Inst_Cycle = []() {
|
|
if constexpr(NPerXDL == 16)
|
|
{
|
|
return KPerXDL == 128 ? 32 / C_MFMA_SpeedUp : 16 / C_MFMA_SpeedUp;
|
|
}
|
|
else if constexpr(NPerXDL == 32)
|
|
{
|
|
return KPerXDL == 64 ? 64 / C_MFMA_SpeedUp : 32 / C_MFMA_SpeedUp;
|
|
}
|
|
}();
|
|
|
|
static constexpr auto Print()
|
|
{
|
|
printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
|
|
BlockSize,
|
|
WaveSize,
|
|
MPerBlock,
|
|
NPerBlock,
|
|
KPerBlock,
|
|
MPerXDL,
|
|
NPerXDL,
|
|
KPerXDL);
|
|
|
|
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
|
|
"%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n"
|
|
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
|
|
"%d/ %d\n",
|
|
A_Buffer_Load_Inst_Num,
|
|
B_Buffer_Load_Inst_Num,
|
|
A_LDS_Write_Inst_Num,
|
|
B_LDS_Write_Inst_Num,
|
|
A_LDS_Read_Inst_Num,
|
|
B_LDS_Read_Inst_Num,
|
|
C_MFMA_Inst_Num,
|
|
C_MFMA_Inst_Cycle,
|
|
A_LDS_Read_Width,
|
|
B_LDS_Read_Width,
|
|
ALDSWriteWidth,
|
|
BLDSWriteWidth,
|
|
ABufferLoadWidth,
|
|
BBufferLoadWidth);
|
|
}
|
|
};
|
|
|
|
} // namespace ck
|