Optimized GEMMs for MX FP4/8 (#2294)

Adds V3 GEMM pipeline for MX FP4 and MX FP8 
Adds V3 GEMM pipeline for MX FP4 with preshuffling
Adds MXFP4 GEMM tests (#2275)
Adds MXFP4 GEMM examples
Adds MXFP4 GEMMs to ckProfiler
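For orientation, below is a minimal host-side sketch of what a block-scaled (MX) GEMM computes: every ScaleBlockSize-long run of K elements in A and B shares one E8M0 scale, and accumulation happens in float (matching AccType in the pipelines below). The decoded-float inputs, the row-major A / N-by-K B layouts, and the ScaleBlockSize of 32 are assumptions for illustration; this sketch is not part of the commit.

// Minimal reference of a block-scaled (MX) GEMM, for illustration only.
// Assumptions (not from this commit): A is M x K row-major, B is N x K row-major,
// C is M x N row-major, FP4/FP8 values are already decoded to float, scales are
// decoded from E8M0 to float, and kScaleBlockSize divides K.
#include <cstddef>
#include <vector>

constexpr std::size_t kScaleBlockSize = 32; // one shared scale per 32 K-elements (assumed)

void mx_gemm_reference(std::size_t M, std::size_t N, std::size_t K,
                       const std::vector<float>& a,       // M*K decoded FP4/FP8 values
                       const std::vector<float>& a_scale, // M*(K/kScaleBlockSize) decoded scales
                       const std::vector<float>& b,       // N*K decoded FP4/FP8 values
                       const std::vector<float>& b_scale, // N*(K/kScaleBlockSize) decoded scales
                       std::vector<float>& c)             // M*N accumulator (float, like AccType)
{
    const std::size_t scale_blocks_per_row = K / kScaleBlockSize;
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
            {
                const std::size_t kb = k / kScaleBlockSize; // scale block index along K
                acc += (a[m * K + k] * a_scale[m * scale_blocks_per_row + kb]) *
                       (b[n * K + k] * b_scale[n * scale_blocks_per_row + kb]);
            }
            c[m * N + n] = acc;
        }
}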




Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>

Authored by Andriy Roshchenko on 2025-06-05 13:54:15 -06:00; committed by GitHub
parent 233e274077
commit 00247e3c29
83 changed files with 8193 additions and 2165 deletions

View File

@@ -35,6 +35,9 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
using ComputeTypeB = BDataType;
using AccType = float; // for now only support V_MFMA_SCALE_F32
static constexpr index_t APackedSize = packed_size_v<ComputeTypeA>;
static constexpr index_t BPackedSize = packed_size_v<ComputeTypeB>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -48,17 +51,24 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
// static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr index_t B_K1 =
BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB, TransposeC, true>{};
static constexpr auto xdlops_gemm = XdlopsGemm<ComputeTypeA,
MPerXDL,
NPerXDL,
KPack * APackedSize,
ComputeTypeB,
TransposeC,
true>{};
static constexpr index_t AMmaKStride = KPack;
static constexpr index_t BMmaKStride = KPack;
//> store rows/cols into thread registers in chunks of 16
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
static constexpr index_t KThreadChunk = 16;
static constexpr index_t KThreadChunk = 16 / sizeof(ComputeTypeA);
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
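As a small standalone check of the chunked K addressing described in the KThreadChunk comment above: each thread gathers its K elements in 16-byte chunks whose starts are strided by the number of mfma input blocks, mirroring the `chunk * KThreadChunk * num_input_blks` offset used later in the pipelines. The concrete K1PerXdlops, num_input_blks, and 8-bit element size below are assumed example values, not values from this diff.

// Sketch of the chunked K addressing: which k indices one thread's registers hold.
// Assumed example parameters: an 8-bit compute type (KThreadChunk = 16 / 1 = 16),
// K1PerXdlops = 32, num_input_blks = 2. With num_input_blks = 4 the same formula
// yields the [k0..k15, k64..k79] pattern from the comment above.
#include <cstdio>

int main()
{
    constexpr int KThreadChunk   = 16 / 1; // 16 / sizeof(ComputeTypeA), assuming an 8-bit type
    constexpr int K1PerXdlops    = 32;     // assumed K elements per lane per xdlops call
    constexpr int num_input_blks = 2;      // assumed mfma input block count

    for(int chunk = 0; chunk < K1PerXdlops / KThreadChunk; ++chunk)
    {
        const int start = chunk * KThreadChunk * num_input_blks; // matches a_k_step_chunk
        std::printf("chunk %d -> k[%d..%d]\n", chunk, start, start + KThreadChunk - 1);
    }
    // prints: chunk 0 -> k[0..15], chunk 1 -> k[32..47]
}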
@@ -67,22 +77,29 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
using HotLoopInstList =
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
A_K1,
B_K1,
A_K1,
B_K1,
MRepeat,
NRepeat,
MPerXDL,
NPerXDL,
xdlops_gemm.KPerXdlops>;
// Hardcoded to 2 for a better 8-bit access pattern
static constexpr index_t MXdlPack = 2;
static constexpr index_t NXdlPack = 2;
static constexpr index_t KXdlPack = 2;
using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< //
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
A_K1,
B_K1,
A_K1,
B_K1,
MRepeat,
NRepeat,
MPerXDL,
NPerXDL,
xdlops_gemm.KPerXdlops,
(packed_size_v<ComputeTypeA> > 1 || packed_size_v<ComputeTypeB> > 1)>;
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
@@ -116,7 +133,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
return make_tuple(0, waveId_m, 0, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
@@ -127,7 +144,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
return make_tuple(0, waveId_n, 0, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
@@ -142,24 +159,27 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(
make_unmerge_transform(make_tuple(MRepeat / MXdlPack, MWaves, MXdlPack, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
make_tuple(Sequence<0, 1, 2, 3>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(
make_unmerge_transform(make_tuple(NRepeat / NXdlPack, NWaves, NXdlPack, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
make_tuple(Sequence<0, 1, 2, 3>{}));
// We pack 2 mfma in M/N direction, so we need to divide by 2
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
make_tuple(m0 / MXdlPack, waveId_m, m0 % MXdlPack, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
make_tuple(n0 / NXdlPack, waveId_n, n0 % NXdlPack, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
/**
* @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base.
@@ -179,13 +199,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
* repeat dimensions.
*/
__host__ __device__
BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
Tuple5 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
@@ -221,6 +240,28 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// XDL output supporting C_xdl = A_xdl * B_xdl, packed mfma
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
Number<NRepeat / NXdlPack>{},
I1,
I1,
Number<MXdlPack>{},
Number<NXdlPack>{},
M0,
M1,
M2,
N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
@@ -262,6 +303,23 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl, packed mfma
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
Number<NRepeat / NXdlPack>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MXdlPack>{},
Number<NXdlPack>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
@@ -314,45 +372,47 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
__host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_m3_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_n3_k;
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<MRepeat / MXdlPack>{}, I1, Number<MXdlPack>{}, Number<KRepeat>{}, Number<KPack>{}));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<NRepeat / NXdlPack>{}, I1, Number<NXdlPack>{}, Number<KRepeat>{}, Number<KPack>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
Number<NRepeat / NXdlPack>{},
Number<MXdlPack>{},
Number<NXdlPack>{},
xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_block_desc_m0_m1_m2_m3_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3>,
3,
Sequence<1, 1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_block_desc_n0_n1_n2_n3_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3>,
3,
Sequence<1, 1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
B_K1>;
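As a rough illustration of the new MXdlPack-aware M mapping above (the unmerge over (MRepeat/MXdlPack, MWaves, MXdlPack, MPerXDL) used in CalculateCThreadOriginDataIndex), the sketch below flattens the same four coordinates by hand. MXdlPack = 2 comes from this diff; MWaves, MPerXDL, and MRepeat are assumed example values, and the helper is illustrative rather than the library's adaptor.

// Sketch of the packed-XDL row mapping implied by the unmerge adaptor:
// (MRepeat/MXdlPack, MWaves, MXdlPack, MPerXDL) flattened back to an M index.
#include <cstdio>

int main()
{
    constexpr int MXdlPack = 2;  // hardcoded to 2 in this diff
    constexpr int MWaves   = 2;  // assumed
    constexpr int MPerXDL  = 32; // assumed
    constexpr int MRepeat  = 4;  // assumed

    auto row_of = [&](int m0, int wave, int blk_m) {
        // same coordinate order as make_unmerge_transform(MRepeat/MXdlPack, MWaves, MXdlPack, MPerXDL)
        return (((m0 / MXdlPack) * MWaves + wave) * MXdlPack + m0 % MXdlPack) * MPerXDL + blk_m;
    };

    for(int m0 = 0; m0 < MRepeat; ++m0)
        std::printf("m0=%d wave=0 blk=0 -> M row %d\n", m0, row_of(m0, 0, 0));
    // m0 = 0,1 stay within the first packed pair of XDL tiles; m0 = 2,3 jump past the other waves
}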

View File

@@ -145,7 +145,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
using Base::MWaves;
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, BDataType>{};
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType>{};
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;

View File

@@ -270,10 +270,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
__builtin_amdgcn_sched_barrier(0);
// // Local prefill A1
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// // Global prefetch A2
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

View File

@@ -58,11 +58,21 @@ struct BlockwiseGemmXdlops_pipeline_base
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerInnerLoop = KPack;
static constexpr index_t KGroup =
((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
(MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
? 2
: 1;
static constexpr index_t KGroup = []() {
if constexpr(is_same_v<remove_cvref_t<ComputeDataType>, f8_t>)
// On gfx950, some mfma instructions require 32 f8 elements as input,
// split into 2 groups of 16 f8 elements.
// The 2 groups are not contiguous in the B preshuffled layout,
// and we do not want them to be contiguous there,
// because a memory instruction can only read 16 f8 elements at a time.
return ((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
(MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
? 2
: 1;
else
return 1;
}();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
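The arithmetic behind KGroup = 2 in the f8 branch above, as a tiny standalone check. The 32-elements-per-mfma and 16-elements-per-LDS-read figures restate the comment in the diff; treating them as plain constants here is an assumption for illustration.

// Why KGroup is 2 for the gfx950 f8 mfma shapes named above (16x16x128, 32x32x64):
// each lane needs 32 f8 inputs per mfma, but one LDS read covers only 16 bytes.
#include <cstdio>

int main()
{
    constexpr int f8_per_lds_read   = 16; // 16 one-byte f8 elements per memory instruction
    constexpr int f8_inputs_per_mfma = 32; // per-lane K elements for the shapes in the condition (assumed)
    constexpr int KGroup            = f8_inputs_per_mfma / f8_per_lds_read; // -> 2
    std::printf("KGroup = %d\n", KGroup);
}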

View File

@@ -0,0 +1,68 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp"
namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
// must be used for compute
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
{
// Hardware MX GEMM pipeline
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
}
}
} // namespace ck

View File

@@ -4,38 +4,9 @@
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp"
namespace ck {
/**
* @brief Define matrix data types that have hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_data_type()
{
return is_same_v<T, f8_ocp_t> || is_same_v<T, bf8_ocp_t> || is_same_v<T, f6_t> ||
is_same_v<T, bf6_t> || is_same_v<T, f4_t>;
}
/**
* @brief Define scale data types that have hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_scale_type()
{
return is_same_v<T, e8m0_bexp_t>;
}
/**
* @brief Combination of data types that have hardware support for MX GEMMs
*/
template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
static constexpr bool scale_mfma_hw_support()
{
return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
}
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t ThreadBlockSize,
@@ -89,6 +60,30 @@ constexpr auto BlockGemmMXPipeline_Selector()
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return BlockwiseGemmXdlops_pipeline_v3_mx<BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;

View File

@@ -205,7 +205,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -136,15 +136,21 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::a_block_desc_m0_m1_m2_m3_k;
using Base::b_block_desc_n0_n1_n2_n3_k;
using Base::AMmaKStride;
using Base::APackedSize;
using Base::BMmaKStride;
using Base::BPackedSize;
using Base::KThreadChunk;
using Base::KXdlPack;
using Base::MXdlPack;
using Base::NXdlPack;
using AccType = typename Base::AccType;
using Tuple4 = typename Base::Tuple4;
using Tuple5 = typename Base::Tuple5;
using ComputeTypeA = typename Base::ComputeTypeA;
using ComputeTypeB = typename Base::ComputeTypeB;
@@ -156,11 +162,26 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col are processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
static constexpr auto AScalesPerXdlopsRun =
(APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
static constexpr auto BScalesPerXdlopsRun =
(BPackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
static constexpr auto ScalesPerXdlopsRunPerThreadA =
AScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
static constexpr auto ScalesPerXdlopsRunPerThreadB =
BScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
using mx_scale_t = e8m0_bexp_t;
static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
"A scale pack data type too large!");
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
"B scale pack data type too large!");
static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
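A quick worked example of the scale vector sizes defined above, assuming AScaleDataType is a plain e8m0_bexp_t (one byte, so scale_pack_size_a = 1); the KXdlPack/MXdlPack values of 2 come from this diff, everything else is assumed for illustration.

// Worked sizes for the A-side scale vector above.
#include <cstdio>

int main()
{
    constexpr int KXdlPack = 2, MXdlPack = 2;                         // from this diff
    constexpr int sizeof_mx_scale   = 1;                              // e8m0_bexp_t is one byte
    constexpr int sizeof_AScale     = 1;                              // assume AScaleDataType == e8m0_bexp_t
    constexpr int scale_pack_size_a = sizeof_AScale / sizeof_mx_scale;          // 1
    constexpr int a_scale_vec_size  = KXdlPack * MXdlPack / scale_pack_size_a;  // 4 scales per packed mfma group
    std::printf("a_scale_thread_vec_size = %d\n", a_scale_vec_size);
    // A wider packed scale type (e.g. 4 e8m0 scales per element) would shrink the vector to 1.
}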
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
@@ -232,76 +253,58 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
b_scale_thread_desc.GetElementSpaceSize());
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_copy);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, k0, I0),
a_scale_thread_buf);
a_scale_thread_buf(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(0, I1, 0));
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, k0, I0),
b_scale_thread_buf);
b_scale_thread_buf(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(0, I1, 0));
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
__builtin_amdgcn_s_waitcnt(3952); // wait for EXP_CNT, LDS, GDS, Constant and Message
block_sync_lds();
// Initialize C
c_thread_buf.Clear();
@@ -314,13 +317,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
do
{
// -------------------------------------------------------------------------------------------
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
// wait previous blockwise copy to finish
// k indexes mapping to threads for 32x32x64:
// t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
@@ -335,160 +333,184 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
// k = 0 k = 1
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
// load for next k loop
block_sync_lds();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
static_assert(0 < ScalesPerXdlopsRunPerThreadA &&
0 < ScalesPerXdlopsRunPerThreadB,
"Must have at least one scale per Xdlops per Thread.");
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(
Number<c_offset>{}));
});
});
});
});
});
});
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_copy);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, k0, I0),
a_scale_thread_buf);
a_scale_thread_buf(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(0, I1, 0));
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize));
a_scale_grid_desc,
make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, k0, I0),
b_scale_thread_buf);
b_scale_thread_buf(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(0, I1, 0));
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
b_scale_grid_desc,
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
__builtin_amdgcn_s_waitcnt(3952); // wait for EXP_CNT and LGKM_CNT
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
@@ -497,87 +519,128 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThreadA &&
0 < ScalesPerXdlopsRunPerThreadB,
"Must have at least one scale per Xdlops per Thread.");
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
});
});
@@ -587,20 +650,16 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
// TODO: make this field protected when a_scale_thread_copy_ is moved
// here
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from a_scale_grid to a_scale_thread
static constexpr auto a_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
make_tuple(Number<MRepeat / MXdlPack>{},
Number<KRepeat / KXdlPack>{},
Number<ScalesPerXdlopsRunPerThreadA * a_scale_thread_vec_size>{}));
// TODO: make this field protected when b_scale_thread_copy_ is moved
// here
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from b_scale_grid to b_scale_thread_buf
static constexpr auto b_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
make_tuple(Number<NRepeat / NXdlPack>{},
Number<KRepeat / KXdlPack>{},
Number<ScalesPerXdlopsRunPerThreadB * b_scale_thread_vec_size>{}));
protected:
using Base::a_thread_copy_;

View File

@@ -177,8 +177,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -179,7 +179,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -178,7 +178,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

File diff suppressed because it is too large.

View File

@@ -188,7 +188,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -42,10 +42,12 @@ namespace ck {
template <typename ThreadGroup,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t ScalarPerVector>
@@ -61,6 +63,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
@@ -96,8 +99,12 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
// elements = 64 consecutive DWORDs.
#if defined(__gfx950__)
int num_contiguous_dwords = 4;
#else
int num_contiguous_dwords = 1;
bool is_contiguous = true;
#endif
bool is_contiguous = true;
static_for<0, nDim, 1>{}([&](auto i) {
if(is_contiguous)
{
@@ -141,11 +148,11 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"When loading more than one element per thread at once, the contiguous "
"dimension must be the same between source and destination.");
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
static_assert(bytes_per_thread_load == dword_bytes,
"Direct load transfer requires each thread to load exactly a single "
"DWORD of data.");
// constexpr auto dword_bytes = 4;
// constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
// static_assert(bytes_per_thread_load == dword_bytes,
// "Direct load transfer requires each thread to load exactly a single "
// "DWORD of data.");
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
@@ -156,18 +163,45 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"The number of threads cannot be less than the number of elements in "
"thread cluster lengths.");
static_assert(
AreThreadClusterLengthsValid(),
"Thread cluster lengths are incorrect. They must be set in a way that allows a single "
"wavefront to write contiguous DWORDs into LDS memory. ");
// static_assert(
//     AreThreadClusterLengthsValid(),
//     "Thread cluster lengths are incorrect. They must be set in a way that allows a single "
//     "wavefront to write contiguous DWORDs into LDS memory.");
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
constexpr auto wave_cluster_lengths = generate_sequence_v2(
[&](auto i) {
// FIXME: wave parallelism is not always along this dimension.
// ThreadClusterLengths{} must be larger than the number of waves.
if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3))
{
return Number<ThreadGroup::GetNumOfThread() / 64>{};
}
else
{
return I1;
}
},
Number<nDim>{});
constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths;
constexpr auto wave_single_load_size =
wave_thread_cluster_lengths * thread_single_load_size;
constexpr auto wave_cluster_desc_ =
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() / 64));
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
// We don't need a threadwise offset for LDS since it is calculated by hardware,
// but we still need to provide the wavewise offset.
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -215,7 +249,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
// Loop over the destination block and copy data.
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
const auto src_offset = src_coord_.GetOffset();
const auto dst_offset = dst_coord_.GetOffset();
const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
// Check if src data is not in the logic padding area.
const bool is_src_valid =
@@ -303,7 +337,8 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
}
private:
static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{});
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
SrcCoord src_coord_;
DstCoord dst_coord_;
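A one-dimensional, heavily simplified sketch of the thread-origin vs. wave-origin split that the direct-load constructor above now performs: the source origin stays per-thread, while the destination (LDS) origin is per-wave because the per-thread LDS offsets are generated by hardware. All sizes and the index math below are assumed example values, not the library's actual cluster descriptors.

// Thread-level source origin vs. wave-level destination origin, one dimension only.
#include <cstdio>

int main()
{
    constexpr int threads_per_block = 256; // assumed
    constexpr int wave_size         = 64;  // wavefront size used by the diff's division
    constexpr int cluster_len       = 64;  // thread cluster length along the wave-parallel dim (assumed)
    constexpr int load_per_thread   = 4;   // thread_single_load_size in that dimension (assumed)

    constexpr int wave_cluster_len        = threads_per_block / wave_size;          // 4 waves
    constexpr int wave_thread_cluster_len = cluster_len / wave_cluster_len;         // 16 threads per wave
    constexpr int wave_single_load        = wave_thread_cluster_len * load_per_thread; // 64 elements per wave

    const int tid          = 100;                                   // example thread id
    const int thread_begin = (tid % cluster_len) * load_per_thread; // source-side origin per thread
    const int wave_begin   = (tid / wave_size) * wave_single_load;  // destination (LDS) origin per wave
    std::printf("thread src begin = %d, wave dst begin = %d\n", thread_begin, wave_begin);
}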

View File

@@ -45,6 +45,44 @@ struct DeviceGemmMX : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
index_t ScaleBlockSize,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmMX_BPreshuffle : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideAScale,
ck::index_t StrideB,
ck::index_t StrideBScale,
ck::index_t StrideC,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual int GetPreShuffleParameters() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
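A hypothetical call-site sketch for the DeviceGemmMX_BPreshuffle interface declared above. DeviceOp, the pointers, sizes, strides, element-wise ops, and stream config are caller-supplied placeholders, and the BaseInvoker::Run usage assumes the usual CK invoker pattern; the argument order simply follows MakeArgumentPointer as declared in this diff.

// Hypothetical usage of the new base interface; not part of this commit.
template <typename DeviceOp,
          typename AElementwiseOp,
          typename BElementwiseOp,
          typename CElementwiseOp,
          typename StreamCfg>
void run_mx_bpreshuffle_gemm(DeviceOp& device_op,
                             const void* p_a, const void* p_a_scale,
                             const void* p_b, const void* p_b_scale, void* p_c,
                             int M, int N, int K,
                             int StrideA, int StrideAScale,
                             int StrideB, int StrideBScale, int StrideC,
                             int KBatch,
                             AElementwiseOp a_op, BElementwiseOp b_op, CElementwiseOp c_op,
                             const StreamCfg& stream_config)
{
    // Build the argument in the order declared by MakeArgumentPointer above.
    auto argument_ptr = device_op.MakeArgumentPointer(p_a, p_a_scale, p_b, p_b_scale, p_c,
                                                      M, N, K,
                                                      StrideA, StrideAScale,
                                                      StrideB, StrideBScale, StrideC,
                                                      KBatch, a_op, b_op, c_op);
    // Launch through the invoker exposed by the same interface.
    auto invoker_ptr = device_op.MakeInvokerPointer();
    invoker_ptr->Run(argument_ptr.get(), stream_config);
}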

View File

@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -162,56 +163,108 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using GridwiseGemm = conditional_t< //
!is_same_v<BLayout, tensor_layout::gemm::MFMA>,
GridwiseGemmMX_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>,
GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle<
ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>>;
using Argument = typename GridwiseGemm::Argument;
@@ -304,385 +357,45 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
: 1
: 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
                // Tail number is Full for pipeline v1, Odd or Even for pipeline v3
                constexpr auto TailNumChoices = []() {
                    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                        return Tuple<constant<TailNumber::Full>>{};
                    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                        return Tuple<constant<TailNumber::Even>, constant<TailNumber::Odd>>{};
                    else
                        static_assert(false, "Unexpected BlkGemmPipelineVer!");
                }();
constexpr bool Use2LDS = []() {
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
return false;
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
return true;
else
static_assert(false, "Unexpected BlkGemmPipelineVer!");
}();
const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split);
using BoolChoices = Tuple<ck::true_type, ck::false_type>;
static_for_product<BoolChoices,
BoolChoices,
remove_cvref_t<decltype(TailNumChoices)>>{}(
[&](auto mainloop_choice, auto KBatch_cond_choice, auto tail_num_choice) {
constexpr auto CGlobalMemoryDataOperation =
KBatch_cond_choice.value ? InMemoryDataOperationEnum::AtomicAdd
: InMemoryDataOperationEnum::Set;
if(mainloop_choice.value == has_main_k_block_loop &&
KBatch_cond_choice.value == (arg.KBatch > 1) &&
tail_num_choice.value == tail_num)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_mx< //
Use2LDS,
GridwiseGemm,
mainloop_choice.value,
CGlobalMemoryDataOperation,
minimum_occupancy,
tail_num_choice.value>;
Run(kernel);
}
                    });
            }
return ave_time;
}
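The dispatch above has to turn a few runtime facts (does the main K loop run, is KBatch > 1, how many tail iterations remain) into exactly one fully specialized kernel instantiation, which is what the static_for_product over Tuple<constant<TailNumber::...>> choices does compactly. A minimal standalone sketch of that pattern, with made-up names (fake_kernel_launch, for_each_tail) and none of the CK machinery:

#include <cstdio>
#include <type_traits>

enum class TailNumber { Odd, Even, Full };

// stand-in for a __global__ kernel template; every (MainLoop, Tail) pair is a
// distinct instantiation, which is why the dispatch must enumerate them
template <bool MainLoop, TailNumber Tail>
void fake_kernel_launch()
{
    std::printf("launch: main_loop=%d tail=%d\n", int(MainLoop), static_cast<int>(Tail));
}

template <TailNumber... Tails, typename F>
void for_each_tail(F&& f)
{
    // expand the compile-time candidate list, calling f once per candidate
    (f(std::integral_constant<TailNumber, Tails>{}), ...);
}

void dispatch(bool has_main_loop, TailNumber tail)
{
    for_each_tail<TailNumber::Odd, TailNumber::Even, TailNumber::Full>([&](auto tail_c) {
        constexpr TailNumber t = decltype(tail_c)::value;
        if(t == tail) // the runtime tail count selects exactly one instantiation
        {
            if(has_main_loop)
                fake_kernel_launch<true, t>();
            else
                fake_kernel_launch<false, t>();
        }
    });
}

int main()
{
    dispatch(true, TailNumber::Even);
    dispatch(false, TailNumber::Odd);
}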

View File

@@ -98,10 +98,12 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK<ALayo
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferScalarPerVector,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferScalarPerVector,
BBlockLdsAddExtraN,

View File

@@ -315,6 +315,13 @@ struct PassThrough
y = x;
}
template <>
__host__ __device__ void operator()<f4x2_pk_t, f4x2_pk_t>(f4x2_pk_t& y,
const f4x2_pk_t& x) const
{
y = x;
}
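For packed types such as f4x2_pk_t, two 4-bit values share one byte, so an elementwise pass-through is just a byte copy. A standalone illustration with a stand-in struct (F4x2 below is illustrative, not the ck type):

#include <cstdint>
#include <cstdio>

struct F4x2 // stand-in for a packed fp4 pair: low nibble = element 0, high nibble = element 1
{
    uint8_t data;
};

int main()
{
    F4x2 x{0x2A}, y{0};
    y = x; // what the f4x2_pk_t PassThrough specialization amounts to: copy both packed elements
    std::printf("lo=%d hi=%d\n", y.data & 0xF, (y.data >> 4) & 0xF); // prints lo=10 hi=2
}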
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{

View File

@@ -173,18 +173,34 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, destination of blockwise copy.
if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
            // FIXME: our support for non-K-contiguous layouts is limited and only works in
            // some specific settings
return make_naive_tensor_descriptor_packed(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1));
}
else
{
return make_naive_tensor_descriptor(make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(AK1, Number<KPerBlock>{}, I1));
}
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, destination of blockwise copy.
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
            // FIXME: our support for non-K-contiguous layouts is limited and only works in
            // some specific settings
return make_naive_tensor_descriptor_packed(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1));
}
else
{
return make_naive_tensor_descriptor(make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(BK1, Number<KPerBlock>{}, I1));
}
}
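As a sanity check on the stride choices above, here is the offset arithmetic the two descriptor variants encode, written as plain functions (not CK code; KPerBlock = K0PerBlock * K1, tile sizes below are assumed for illustration):

#include <cstdio>

// packed (K-contiguous) variant: lengths (K0PerBlock, MPerBlock, K1), strides (MPerBlock*K1, K1, 1)
long offset_packed(long k0, long m, long k1, long MPerBlock, long K1)
{
    return k0 * MPerBlock * K1 + m * K1 + k1;
}

// strided variant used above for the other layout: strides (K1, KPerBlock, 1),
// so the K elements of one row stay contiguous and rows are KPerBlock apart
long offset_strided(long k0, long m, long k1, long KPerBlock, long K1)
{
    return k0 * K1 + m * KPerBlock + k1;
}

int main()
{
    // assumed tile: K0PerBlock = 4, K1 = 8 (KPerBlock = 32), MPerBlock = 16
    std::printf("packed: %ld, strided: %ld\n",
                offset_packed(1, 2, 3, 16, 8),   // 1*128 + 2*8  + 3 = 147
                offset_strided(1, 2, 3, 32, 8)); // 1*8   + 2*32 + 3 = 75
}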
__host__ __device__ static constexpr auto
@@ -566,10 +582,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferSrcAccessOrder,
ADataType,
AComputeDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
2,
ABlockTransferScalarPerVector>(
@@ -582,10 +600,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferSrcAccessOrder,
BDataType,
BComputeDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
BBlockTransferScalarPerVector>(

View File

@@ -256,8 +256,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
            // gfx950 double-rate mfma16x16 requires at least 128 KPerBlock to consume
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
KPerBlock < 128 && MPerXdl == 16))
? true
: false;
static constexpr auto is_scale_mfma = false;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -184,8 +184,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
            // gfx950 double-rate mfma16x16 requires at least 128 KPerBlock to consume
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
KPerBlock < 128 && MPerXdl == 16))
? true
: false;
static constexpr auto is_scale_mfma = false;

View File

@@ -173,15 +173,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>{};
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
static constexpr index_t KGroup = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, f8_t>)
            // On gfx950, there is an mfma that requires 32 f8 elements as input,
            // split into 2 groups of 16 f8 elements.
            // The 2 groups are not contiguous in the B preshuffled layout, and we
            // do not want them to be contiguous there, because a memory instruction
            // can only read 16 f8 elements at a time.
return mfma.selected_mfma.k_per_blk == 32 ? 2 : 1;
else
return 1;
}();
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
static constexpr index_t KPackPerGroup = KPack / KGroup;
static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup;
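A quick numeric illustration of the constants above for one assumed configuration (values chosen for illustration, not taken from a shipped instance):

#include <algorithm>
#include <cstdio>

int main()
{
    const int k_per_blk   = 32;  // gfx950 f8 mfma consumes 32 f8 elements per block
    const int lcm_AK1_BK1 = 16;
    const int KPerBlock   = 128;
    const int KLane       = 2;   // KPerXdlops / K1PerXdlops, assumed here

    const int KPack         = std::max(lcm_AK1_BK1, k_per_blk);  // 32
    const int KGroup        = (k_per_blk == 32) ? 2 : 1;         // two 16-element loads
    const int KPackPerGroup = KPack / KGroup;                    // 16
    const int KRepeat       = KPerBlock / KLane / KPackPerGroup; // 4

    std::printf("KPack=%d KGroup=%d KPackPerGroup=%d KRepeat=%d\n",
                KPack, KGroup, KPackPerGroup, KRepeat);
}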

View File

@@ -76,10 +76,12 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
bool BBlockLdsExtraN,
@@ -102,9 +104,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto KPerBlock = Number<K1Value * K0PerBlock>{};
static constexpr auto M01 = 1;
static constexpr auto N01 = 1;
static constexpr auto gemm_padder =
tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
@@ -613,8 +616,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(K1, Number<KPerBlock>{}, I1));
}
}();
@@ -630,9 +634,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
                return make_naive_tensor_descriptor(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(
                        Number<KPerBlock>{} * Number<MPerBlock>{}, K1, Number<KPerBlock>{}, I1));
}
}();
// B matrix in LDS memory, dst of blockwise copy
@@ -645,8 +650,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(K1, Number<KPerBlock>{}, I1));
}
}();
@@ -662,9 +668,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
                return make_naive_tensor_descriptor(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(
                        Number<KPerBlock>{} * Number<NPerBlock>{}, K1, Number<KPerBlock>{}, I1));
}
}();
@@ -672,10 +679,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferSrcAccessOrder,
FloatA,
ComputeType,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector>(
@@ -688,10 +697,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferSrcAccessOrder,
FloatB,
ComputeType,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector>(

View File

@@ -260,7 +260,8 @@ struct ThreadwiseTensorSliceTransfer_v2
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
@@ -422,6 +423,240 @@ struct ThreadwiseTensorSliceTransfer_v2
SrcCoord src_coord_;
}; // namespace ck
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun,
index_t scale_gather_num,
bool InvalidElementAsNaN = false,
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2_gather
{
static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
(!InvalidElementAsNaN),
"Filling invalid element as NaN is only for floating point types");
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
__device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather(
const SrcDesc& src_desc,
const Index& src_slice_origin_idx,
const StaticallyIndexedArray<index_t, scale_gather_num>& scale_gather_offsets)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)),
scale_gather_offsets_(scale_gather_offsets)
{
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! DstDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
auto adjusted_origin_idx = [&]() {
Index idx;
static_for<0, nDim, 1>{}(
[&](auto i) { idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number<i>{}]; });
return idx;
}();
src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
}
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
{
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! DstDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
"wrong! DstSliceOrigin need to known at compile-time");
static_assert(
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
"wrong! inconsistent type");
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
// loop over tensor and copy
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) {
constexpr auto current_dst_origin =
to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, 0);
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type
src_vector;
using src_vector_t =
typename vector_type_maker<SrcData,
SrcScalarPerVector / PackedSize>::type::type;
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
src_coord_);
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize +
scale_gather_offsets_(gather_idx),
is_src_valid);
// copy data from src_vector into dst_buf
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
src_data_idx + i * src_scalar_step_in_vector);
constexpr auto full_dst_offset =
dst_desc.CalculateOffset(current_dst_origin) + dst_offset;
if constexpr(InvalidElementAsNaN)
{
dst_buf(full_dst_offset) =
is_src_valid
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
: NumericLimits<DstData>::QuietNaN();
}
else
{
dst_buf(Number<full_dst_offset>{}) =
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
}
});
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
}
});
});
// printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n",
// blockIdx.y,
// threadIdx.x,
// dst_buf(Number<0>{}));
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
}
    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
private:
SrcCoord src_coord_;
StaticallyIndexedArray<index_t, scale_gather_num> scale_gather_offsets_;
}; // namespace ck
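The gather variant above keeps one source coordinate per thread and replays the same access pattern scale_gather_num times, each time adding a different precomputed offset. A toy standalone model of that access pattern (plain arrays and illustrative numbers only, not the CK buffers):

#include <array>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> src(64);
    for(std::size_t i = 0; i < src.size(); ++i)
        src[i] = static_cast<float>(i);

    constexpr int scale_gather_num = 2;
    constexpr int per_access       = 4;                      // plays the role of SrcScalarPerVector
    std::array<int, scale_gather_num> gather_offsets{0, 32}; // e.g. two scale rows, 32 apart

    const int base = 8; // offset derived from the thread's source coordinate
    float dst[scale_gather_num][per_access];
    for(int g = 0; g < scale_gather_num; ++g)
        for(int i = 0; i < per_access; ++i)
            dst[g][i] = src[base + gather_offsets[g] + i];

    std::printf("%g %g\n", dst[0][0], dst[1][0]); // 8 and 40
}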
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
@@ -1053,10 +1288,8 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
@@ -1236,16 +1469,16 @@ struct ThreadwiseTensorSliceTransfer_v4
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector / PackedSize> dst_tmp_vector;
                // TODO: if SrcData and DstData are vector types, then static_cast may not compile
                static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

View File

@@ -62,6 +62,18 @@ struct lambda_scalar_per_access_for_src_and_dst
}
};
template <index_t WaveNum, index_t nDim>
struct lambda_wave_cluster_dimension
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if((nDim - i) == 3)
return WaveNum;
else
return 1;
}
};
} // namespace detail
} // namespace ck
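A quick standalone check of what lambda_wave_cluster_dimension produces: only dimension nDim - 3 gets the wave count, every other dimension gets 1 (nDim = 4 and WaveNum = 4 below are assumed values for illustration):

#include <cstdio>

int main()
{
    const int nDim = 4, WaveNum = 4;
    for(int i = 0; i < nDim; ++i)
        std::printf("%d ", (nDim - i) == 3 ? WaveNum : 1); // prints: 1 4 1 1
    std::printf("\n");
}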

View File

@@ -90,7 +90,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
if constexpr((packed_size_v<SrcData>) > 1)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
@@ -99,7 +99,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim,
"Packed data type does not support transpose");
}
}
@@ -444,6 +445,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{
static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
"in-register transpose is not supported for pk_i4_t");
static_assert(!is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>,
"in-register transpose is not supported for f4x2_pk_t");
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -96,7 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
dst_element_op_(dst_element_op),
gather_offsets_(gather_offsets)
{
if constexpr((packed_size_v<SrcData>) > 1)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
@@ -105,7 +105,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim,
"Packed data type does not support transpose");
}
}
@@ -222,7 +223,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
auto gather_offset =
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
const IndexType ld_offset = src_coord_.GetOffset() / PackedSize + gather_offset;
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, true);

View File

@@ -410,8 +410,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(

View File

@@ -8,6 +8,35 @@
#include "ck/utility/amd_xdlops.hpp"
namespace ck {
/**
 * @brief Check whether a matrix data type has hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_data_type()
{
using U = element_type_t<T>;
return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
}
/**
 * @brief Check whether a scale data type has hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_scale_type()
{
return is_same_v<T, e8m0_bexp_t>;
}
/**
* @brief Combination of data types that have hardware support for MX GEMMs
*/
template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
static constexpr bool scale_mfma_hw_support()
{
return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
}
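The three helpers above act as a compile-time gate on the A/B data types and their scale types. A standalone sketch of the same kind of gate with stand-in types (Fp4, Fp8, Bf16, E8M0 below are illustrative, not the ck types):

#include <type_traits>

struct Fp4 {};
struct Fp8 {};
struct Bf16 {};
struct E8M0 {};

template <typename T>
constexpr bool is_mx_data = std::is_same_v<T, Fp4> || std::is_same_v<T, Fp8>;

template <typename T>
constexpr bool is_mx_scale = std::is_same_v<T, E8M0>;

template <typename A, typename B, typename SA, typename SB>
constexpr bool mx_hw_support = is_mx_data<A> && is_mx_data<B> && is_mx_scale<SA> && is_mx_scale<SB>;

static_assert(mx_hw_support<Fp4, Fp8, E8M0, E8M0>, "fp4 x fp8 with e8m0 scales is accepted");
static_assert(!mx_hw_support<Bf16, Fp8, E8M0, E8M0>, "a bf16 A operand is rejected");

int main() {}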
enum struct MfmaInstr
{
@@ -847,6 +876,8 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t OpselA,
index_t OpselB,
class FloatA,
class ScaleA,
class FloatB,
@@ -858,11 +889,9 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
const ScaleB& scale_b,
FloatC& reg_c) const
{
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
}
};
@@ -885,6 +914,8 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t OpselA,
index_t OpselB,
class FloatA,
class ScaleA,
class FloatB,
@@ -896,11 +927,9 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
const ScaleB& scale_b,
FloatC& reg_c) const
{
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
}
};
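The switch from get_exponent_value to bit_cast<uint32_t> plus OpselA/OpselB reflects, as far as can be read from this diff, that several e8m0 scale bytes ride in one 32-bit register and the opsel immediate picks which byte feeds the MFMA. A plain-C++ illustration of that byte selection (no HIP intrinsics involved; the packed value is made up):

#include <cstdint>
#include <cstdio>

// return the opsel-th byte of a packed 32-bit scale register (opsel in [0, 3])
uint8_t select_scale_byte(uint32_t packed_scales, int opsel)
{
    return static_cast<uint8_t>(packed_scales >> (8 * opsel));
}

int main()
{
    const uint32_t packed = 0x7F807E81u; // four e8m0 exponents packed into one register
    std::printf("opsel=0 -> 0x%02X, opsel=2 -> 0x%02X\n",
                select_scale_byte(packed, 0), select_scale_byte(packed, 2)); // 0x81, 0x80
}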
@@ -1117,7 +1146,7 @@ struct MfmaSelector
#endif
}
// Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
// See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
// TODO: explore optimization opportunity by using new mfma instructions on gfx950
template <>
@@ -1153,6 +1182,16 @@ struct MfmaSelector
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<f4_t, 32, 32, f4_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<f4_t, 16, 16, f4_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
@@ -1290,10 +1329,10 @@ struct MfmaSelector
#endif
}
static constexpr auto selected_mfma = mfma_type<GetMfma<element_type_t<base_type>,
MPerXdlops,
NPerXdlops,
element_type_t<additional_type>,
is_single_rate_mfma,
is_scale_mfma>()>{};
@@ -1375,7 +1414,8 @@ struct XdlopsGemm
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(KPack * 2 % mfma_instr.k_per_blk == 0,
"KPack should be a multiple of k_per_blk");
}
// XDL output supporting C = A * B
@@ -1413,6 +1453,49 @@ struct XdlopsGemm
Sequence<7>{}));
}
// XDL output supporting C = A * B
// M3_N3 -> M3_M4_M5_N3
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_pass_through_transform(M2),
make_pass_through_transform(N2),
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
Number<mfma_instr.num_input_blks>{},
Number<mfma_instr.group_size>{})),
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6, 7, 8>{},
Sequence<9>{}));
}
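The unmerge transform above splits the per-xdlop M2 index into (group, input block, element in group). The index arithmetic it encodes, checked standalone with example mfma parameters (num_input_blks = 2 and group_size = 4 are assumed here, not read from a real instruction descriptor):

#include <cstdio>

int main()
{
    const int num_input_blks = 2, group_size = 4; // assumed mfma parameters
    const int m2 = 27;

    const int elem      = m2 % group_size;                    // 3
    const int input_blk = (m2 / group_size) % num_input_blks; // 0
    const int group     = m2 / (group_size * num_input_blks); // 3

    // recompose to verify: (group * num_input_blks + input_blk) * group_size + elem == m2
    std::printf("m2=%d -> group=%d input_blk=%d elem=%d (recomposed=%d)\n",
                m2, group, input_blk, elem,
                (group * num_input_blks + input_blk) * group_size + elem);
}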
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
template <typename CDesc_M0_N0_M1_N1_M2_N2>
@@ -1518,7 +1601,13 @@ struct XdlopsGemm
});
}
template <class FloatA, class ScaleA, class FloatB, class ScaleB, class FloatC>
template <index_t OpselA,
index_t OpselB,
class FloatA,
class ScaleA,
class FloatB,
class ScaleB,
class FloatC>
__device__ void Run(const FloatA& p_a_wave,
const ScaleA& a_scale_thread,
const FloatB& p_b_wave,
@@ -1528,12 +1617,12 @@ struct XdlopsGemm
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
if constexpr(!TransposeC)
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(
mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
}
else
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(
mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
}
});