mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
CK Tile - small fix to hotloop scheduler & KPack value. (#1867)
* Use SmemPack in HotLoop scheduler
* Additional debug print information
* Change KPack value.
Hardcode for now, as without AK1/BK1 there's no good way to determine
its value.
* Fix HotLoopScheduler MFMA instr parameters.
[ROCm/composable_kernel commit: 0c15de6a83]
This commit is contained in:
@@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
@@ -733,7 +734,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "Kpack: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -90,14 +90,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
|
||||
KPerXDL);
|
||||
|
||||
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
|
||||
"%d, %d\n C MFMA inst: %d\n",
|
||||
"%d, %d\n C MFMA inst: %d\n"
|
||||
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
|
||||
"%d/ %d\n",
|
||||
A_Buffer_Load_Inst_Num,
|
||||
B_Buffer_Load_Inst_Num,
|
||||
A_LDS_Write_Inst_Num,
|
||||
B_LDS_Write_Inst_Num,
|
||||
A_LDS_Read_Inst_Num,
|
||||
B_LDS_Read_Inst_Num,
|
||||
C_MFMA_Inst_Num);
|
||||
C_MFMA_Inst_Num,
|
||||
A_LDS_Read_Width,
|
||||
B_LDS_Read_Width,
|
||||
ALDSWriteWidth,
|
||||
BLDSWriteWidth,
|
||||
ABufferLoadWidth,
|
||||
BBufferLoadWidth);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
// TODO: Should we have two policies? Interwave & Intrawave ??
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
// should be at least equal to: WarpGemm::Impl::kABKPerLane
|
||||
// and the question is how to assess upper limit or exact value?
|
||||
// TODO: Should we introduce AK1/BK1 parameters ?
|
||||
static constexpr index_t KPack = 8;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * KPack;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
};
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
@@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static std::string Print()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
// Below should be equal to AK1|BK1
|
||||
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
|
||||
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
|
||||
|
||||
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
|
||||
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
auto str = std::stringstream{};
|
||||
|
||||
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
|
||||
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
|
||||
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
|
||||
<< "\n"
|
||||
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
@@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
|
||||
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
|
||||
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
|
||||
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
constexpr index_t A_LDS_Read_Width = KPerXDL;
|
||||
constexpr index_t B_LDS_Read_Width = KPerXDL;
|
||||
// Below should be equal to AK1|BK1
|
||||
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
|
||||
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
|
||||
|
||||
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
|
||||
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) /
|
||||
|
||||
@@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
|
||||
Reference in New Issue
Block a user