DeviceGemm_Wmma_CShuffleV3 with BlockGemmPipelineVersion::v3 (#2096)

* Prepare files for DeviceGemm_Wmma_CShuffleV3

* Implement main part of CShuffleV3 with block pipeline v3 for WMMA

* Remove unused functions and template params for A/B descriptors

* Support both gfx11 and gfx12

* Enable SplitK for gfx12 and disable for gfx11

* Add Row, Col, Row layout for DeviceGemmV2 fp16

* Add more instances for Row, Col, Row data layout

* Add instances for DeviceGemm_Wmma_CShuffleV3, Col, Row, Row data layout

* Add instances for DeviceGemm_Wmma_CShuffleV3, Col, Col, Row data layout

* Add more instances for DeviceGemm_Wmma_CShuffleV3, Row, Row, Row data layout

* Fix formatting

* Add documentation

Based on e5ad48a784

* Enable gemm_universal profiling for gfx11/12

* Add WMMA intrinsics for F8/BF8

* Support F8/BF8 DeviceGemm_Wmma_CShuffleV3, add basic instances

* Add BF16 instances and tests

* Fix test_gemm_universal_wmma_fp8 by adding CK_USE_WMMA_FP8

---------

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
Author: Anton Gorenko
Date: 2025-04-28 11:14:21 +06:00
Committed by: GitHub
Parent: 8add2cf45d
Commit: edd92fc546
44 changed files with 5326 additions and 570 deletions


@@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp"
namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
constexpr auto BlockGemmPipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return BlockwiseGemmWmmaops_pipeline_v3<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
}
else
{
static_assert(false, "BlockGemmPipeline configuration is not available");
}
}
} // namespace ck
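
For orientation, here is a hedged sketch of how this selector could be instantiated by a calling kernel, assuming the selector header above is included. The numeric parameters and data types are illustrative placeholders, not values taken from the instance lists added in this PR; the WMMA tile descriptors are left as template parameters because they are normally built by the gridwise GEMM.

// Hypothetical usage sketch; AWmmaTileDesc / BWmmaTileDesc are placeholders for the
// LDS tile descriptors that the gridwise GEMM builds, and all numbers are illustrative.
template <typename AWmmaTileDesc, typename BWmmaTileDesc>
using ExampleWmmaPipeline = decltype(
    ck::BlockGemmPipeline_Selector<ck::BlockGemmPipelineVersion::v3,
                                   ck::BlockGemmPipelineScheduler::Intrawave,
                                   256,         // BlockSize
                                   ck::half_t,  // ADataType
                                   ck::half_t,  // BDataType
                                   ck::half_t,  // ComputeTypeA
                                   ck::half_t,  // ComputeTypeB
                                   float,       // AccDataType
                                   AWmmaTileDesc,
                                   BWmmaTileDesc,
                                   8,           // ABlockTransferSrcScalarPerVector
                                   8,           // BBlockTransferSrcScalarPerVector
                                   128,         // MPerBlock
                                   128,         // NPerBlock
                                   64,          // KPerBlock
                                   16,          // MPerWmma
                                   16,          // NPerWmma
                                   4,           // MRepeat
                                   2,           // NRepeat
                                   16           // KPack
                                   >());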


@@ -0,0 +1,85 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
template <index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t ABufferLoadWidth,
index_t BBufferLoadWidth,
index_t ALDSWriteWidth,
index_t BLDSWriteWidth,
index_t ALDSReadWidth,
index_t BLDSReadWidth,
index_t MRepeat,
index_t NRepeat,
index_t MPerWmma,
index_t NPerWmma,
index_t KPerWmma>
struct BlockwiseGemmWmmaops_pipeline_hotloop_inst
{
static constexpr index_t WaveSize = 32;
static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerWmma);
static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerWmma);
static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
static constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
static constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
static constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
static constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
static constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
static constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
static constexpr index_t C_WMMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerWmma * NPerWmma * KPerWmma);
static constexpr auto Print()
{
printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerWmma: %d, %d, %d\n",
BlockSize,
WaveSize,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
KPerWmma);
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
"%d, %d\n C WMMA inst: %d\n"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d, %d\n",
A_Buffer_Load_Inst_Num,
B_Buffer_Load_Inst_Num,
A_LDS_Write_Inst_Num,
B_LDS_Write_Inst_Num,
A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num,
C_WMMA_Inst_Num,
A_LDS_Read_Width,
B_LDS_Read_Width,
ALDSWriteWidth,
BLDSWriteWidth,
ABufferLoadWidth,
BBufferLoadWidth);
}
};
} // namespace ck
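
As a sanity check of the formulas above, here is a hedged worked example for one hypothetical configuration (BlockSize = 256, a 128x128x64 tile, 8-element buffer/LDS accesses, MRepeat = 4, NRepeat = 2, 16x16x16 WMMA), assuming the header above is included. None of these values come from the instances added in this PR.

// Illustrative arithmetic only.
using ExampleInst =
    ck::BlockwiseGemmWmmaops_pipeline_hotloop_inst<256,           // BlockSize
                                                   128, 128, 64,  // M/N/KPerBlock
                                                   8, 8,          // A/BBufferLoadWidth
                                                   8, 8,          // A/BLDSWriteWidth
                                                   8, 8,          // A/BLDSReadWidth
                                                   4, 2,          // MRepeat, NRepeat
                                                   16, 16, 16>;   // M/N/KPerWmma
// 128 * 64 / (256 * 8) = 4 buffer loads and 4 LDS writes per operand:
static_assert(ExampleInst::A_Buffer_Load_Inst_Num == 4 && ExampleInst::A_LDS_Write_Inst_Num == 4);
// LDS reads scale with the wave counts (WaveNumN = 4 for A, WaveNumM = 2 for B):
static_assert(ExampleInst::A_LDS_Read_Inst_Num == 16 && ExampleInst::B_LDS_Read_Inst_Num == 8);
// 128 * 128 * 64 / (256 / 32) / (16 * 16 * 16) = 32 = MRepeat * NRepeat * (KPerBlock / KPerWmma):
static_assert(ExampleInst::C_WMMA_Inst_Num == 32);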


@@ -0,0 +1,309 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_base
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I5 = Number<5>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = 32;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
#if defined(__gfx12__)
static constexpr index_t A_KRow = 2;
static constexpr index_t B_KRow = 2;
#else
static constexpr index_t A_KRow = 1;
static constexpr index_t B_KRow = 1;
#endif
static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
static constexpr auto wmma_gemm =
WmmaGemm<ADataType, BDataType, AccDataType, MPerWmma, NPerWmma, KPack, TransposeC>{};
static constexpr index_t KRepeat = KPerBlock / KPack;
static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
using HotLoopInstList =
ck::BlockwiseGemmWmmaops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
A_K1,
B_K1,
A_K1,
B_K1,
MRepeat,
NRepeat,
MPerWmma,
NPerWmma,
wmma_gemm.wmma_instr.k_per_wmma>;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWmma(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto wmma_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
#if defined(__gfx12__)
const auto wmma_krow = wmma_gemm.GetSubGroupId();
#else
const auto wmma_krow = 0;
#endif
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto wmma_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
#if defined(__gfx12__)
const auto wmma_krow = wmma_gemm.GetSubGroupId();
#else
const auto wmma_krow = 0;
#endif
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperwmma_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWmma))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperwmma_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWmma))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperwmma_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperwmma_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
/**
* @brief Constructor for BlockwiseGemmWmmaops_pipeline_base.
*
* This constructor initializes the thread copy objects for matrices A and B.
* It also performs several compile-time checks to ensure the correctness of the
* matrix tile descriptors.
*
* @param a_origin The origin data index for matrix A.
* @param b_origin The origin data index for matrix B.
*
* @note The constructor includes static assertions to ensure that:
* - The matrix tile descriptors for A and B are known at compile-time.
* - The number of threads in the thread block matches the product of MWaves, NWaves, and
* WaveSize.
* - The dimensions of the block are divisible by the product of the corresponding WMMA and
* repeat dimensions.
*/
__host__ __device__
BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
BWmmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 &&
NPerBlock % (NPerWmma * NRepeat) == 0,
"wrong!");
}
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
return make_naive_tensor_descriptor(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
AccStride));
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWmma>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWmma>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Describes how data is laid out in the thread-copy source buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1;
protected:
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
Number<MRepeat>{},
Number<KRepeat>{},
I1,
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<KPack * A_K1>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
Number<NRepeat>{},
Number<KRepeat>{},
I1,
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<KPack * B_K1>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
// C[M, N, NumRegWmma]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<ADataType,
ADataType,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
using BThreadCopy =
ThreadwiseTensorSliceTransfer_v4<BDataType,
BDataType,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
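
The compile-time constraints in this base class are easiest to see with concrete numbers. Below is a hedged check using the same hypothetical tile as above (128x128x64, MRepeat = 4, NRepeat = 2, 16x16 WMMA, KPack = 16, A_K1 = B_K1 = 8); none of these values are taken from the instances in this PR.

// Illustrative arithmetic only.
namespace wmma_pipeline_example {
constexpr int WaveSize = 32;
constexpr int MWaves   = 128 / (4 * 16); // MPerBlock / (MRepeat * MPerWmma) = 2
constexpr int NWaves   = 128 / (2 * 16); // NPerBlock / (NRepeat * NPerWmma) = 4
static_assert(MWaves * NWaves * WaveSize == 256, "BlockSize must equal MWaves * NWaves * WaveSize");
static_assert(64 % 16 == 0, "KPerBlock must be a multiple of KPack (KRepeat = 4 here)");
static_assert(16 % (8 * 2) == 0, "gfx12: KPack % (A_K1 * A_KRow) == 0 with A_KRow = 2");
static_assert(16 % (8 * 1) == 0, "gfx11: KPack % (A_K1 * A_KRow) == 0 with A_KRow = 1");
} // namespace wmma_pipeline_example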


@@ -0,0 +1,466 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"
namespace ck {
// Compute-optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
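// In steady state each hot-loop iteration of Run() below therefore overlaps three K blocks:
// WMMA consumes registers holding block i (read from LDS at the end of the previous
// iteration), the global data already prefetched for block i+1 is written to the single
// LDS buffer, and the buffer loads for block i+2 are issued.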
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v3
{
};
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
using Base::I0;
using Base::A_K1;
using Base::A_KRow;
using Base::B_K1;
using Base::B_KRow;
using Base::KRepeat;
using Base::WmmaK;
using Base::wmma_gemm;
using typename Base::HotLoopInstList;
using Base::CalculateCThreadOriginDataIndex;
using Base::
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::GetCThreadBuffer;
using Base::
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
}
__device__ static constexpr auto HotLoopScheduler()
{
// TODO: Calculation of the number of instructions may require changes for WMMA
/*
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
? HotLoopInstList::A_LDS_Read_Inst_Num
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
? HotLoopInstList::B_LDS_Read_Inst_Num
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num;
constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_wmma_rate =
(wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_wmma_rate =
(wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_a_wmma =
(num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate;
constexpr auto num_dsread_b_wmma =
(num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate;
// stage 1
// Separate this part?
// constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) / sizeof(BDataType)
// ? sizeof(ComputeDataType) / sizeof(ADataType)
// : sizeof(ComputeDataType) / sizeof(BDataType);
constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma);
constexpr auto num_wmma_per_issue =
num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA
});
// stage 2
static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >=
ds_read_a_wmma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_a - (num_dsread_a_wmma - 1) *
ds_read_a_wmma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
});
static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >=
ds_read_b_wmma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_b - (num_dsread_b_wmma - 1) *
ds_read_b_wmma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
});
*/
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// Local prefetch 1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
});
});
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(ik / A_K1, m0, k0, 0, 0, ik % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(ik / B_K1, n0, k0, 0, 0, ik % B_K1))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(ik / A_K1, m0, k0, 0, 0, ik % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(ik / B_K1, n0, k0, 0, 0, ik % B_K1))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// Let the last WMMA block leak into the epilogue region to cover potential
// lds-shuffle latency
// __builtin_amdgcn_sched_barrier(0);
}
}
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
using Base::b_thread_desc_;
using Base::c_thread_desc_;
};
} // namespace ck
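
For completeness, here is a hedged sketch of how a calling (gridwise) kernel is expected to drive this pipeline. Only BlockHasHotloop, BlockLoopTailNum and Run come from the code above; the helper name and the argument pack are placeholders for illustration and are not part of this PR.

// Hypothetical helper. RunArgs stands in for the grid/block descriptors, blockwise
// copies, buffers and copy steps that Run() takes above.
template <typename Pipeline, typename... RunArgs>
__device__ void RunWmmaPipelineExample(const Pipeline& pipeline,
                                       ck::index_t num_k_block_loop,
                                       RunArgs&... run_args)
{
    // BlockLoopTailNum() always reports TailNumber::Full for this pipeline, so only the
    // hot-loop predicate (num_loop > PrefetchStages) changes the instantiation.
    if(Pipeline::BlockHasHotloop(num_k_block_loop))
    {
        pipeline.template Run<true, ck::TailNumber::Full>(run_args..., num_k_block_loop);
    }
    else
    {
        pipeline.template Run<false, ck::TailNumber::Full>(run_args..., num_k_block_loop);
    }
}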