mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 1 (#2606)
This commit is contained in:
@@ -41,7 +41,9 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
|
||||
static constexpr index_t WaveSize = 64;
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
|
||||
@@ -74,9 +76,6 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
using HotLoopInstList =
|
||||
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
|
||||
MPerBlock,
|
||||
@@ -219,6 +218,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
@@ -227,6 +227,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -139,9 +139,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
@@ -625,13 +626,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -141,9 +141,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -139,9 +139,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
@@ -626,13 +627,14 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -159,6 +159,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
#if !defined(__gfx11__) && !defined(__gfx12__)
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
@@ -260,6 +261,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
|
||||
@@ -176,8 +176,36 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
template <bool isWave64>
|
||||
static constexpr auto GetNXdlPerWave()
|
||||
{
|
||||
constexpr index_t Waves = isWave64 ? BlockSize / 64 : BlockSize / 32;
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
|
||||
static_assert(MWaves > 0);
|
||||
|
||||
constexpr index_t NWaves = Waves / MWaves;
|
||||
if constexpr(NWaves == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NPerBlock % (NPerXDL * NWaves) == 0)
|
||||
{
|
||||
return NPerBlock / (NWaves * NPerXDL);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
|
||||
template <index_t NXdlPerWave_>
|
||||
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
@@ -199,7 +227,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -226,8 +254,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
using Argument = typename GridwiseGemm64::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
|
||||
@@ -254,12 +284,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
///
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
template <typename GridwiseGemm>
|
||||
float RunImp(const typename GridwiseGemm::Argument& arg,
|
||||
const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
@@ -285,7 +312,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
auto arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
@@ -297,7 +324,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / BPackedSize;
|
||||
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
@@ -733,6 +760,31 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm64>(arg, stream_config);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm32>(
|
||||
reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg),
|
||||
stream_config);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
@@ -754,9 +806,39 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
if(is_gfx11_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(sizeof(CDataType) == 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(is_gfx11_supported() || is_gfx12_supported())
|
||||
{
|
||||
if(MPerXDL != 16 || NPerXDL != 16)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(is_gfx11_supported())
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck::f8_t> ||
|
||||
std::is_same_v<ADataType, ck::bf8_t>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
@@ -767,7 +849,29 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
{
|
||||
return GridwiseGemm64::CheckValidity(arg);
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
{
|
||||
return GridwiseGemm32::CheckValidity(
|
||||
reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -849,6 +953,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
index_t PrefetchStages = 0;
|
||||
index_t AMmaKStride = 0;
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
{
|
||||
PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
|
||||
AMmaKStride = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
{
|
||||
PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages;
|
||||
AMmaKStride = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride;
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdlUniversal"
|
||||
<< "<"
|
||||
@@ -872,9 +995,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< PrefetchStages << ", "
|
||||
<< "Kpack: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
|
||||
<< AMmaKStride;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -35,20 +35,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
enum struct Arch : bool
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
is_gfx950_build = true,
|
||||
#else
|
||||
is_gfx950_build = false,
|
||||
#endif
|
||||
};
|
||||
// skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
|
||||
if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
|
||||
(GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) ||
|
||||
(GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) ||
|
||||
(GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2))
|
||||
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -78,23 +66,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
enum struct Arch : bool
|
||||
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
|
||||
// Pass two lds pointer is the key to tell compiler that ds_read/write
|
||||
// operate on different lds chunk at same time without order dependecy
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
is_gfx950_build = true,
|
||||
#else
|
||||
is_gfx950_build = false,
|
||||
#endif
|
||||
};
|
||||
// skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
|
||||
if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
|
||||
(GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) ||
|
||||
(GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) ||
|
||||
(GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2))
|
||||
{
|
||||
// Pass two lds pointer is the key to tell compiler that ds_read/write
|
||||
// operate on different lds chunk at same time without order dependecy
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -696,12 +672,23 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
__host__ void Print() const
|
||||
{
|
||||
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
|
||||
<< ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
|
||||
<< "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
|
||||
<< ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
|
||||
// clang-format off
|
||||
std::cout << "problem {"
|
||||
<< "M:" << M << ", "
|
||||
<< "N:" << N << ", "
|
||||
<< "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", "
|
||||
<< "SB:" << StrideB << ", "
|
||||
<< "SC:" << StrideC << ", "
|
||||
<< "MP:" << MPadded << ", "
|
||||
<< "NP:" << NPadded << ", "
|
||||
<< "KRead:" << KRead << ", "
|
||||
<< "KP:" << KPadded << ", "
|
||||
<< "AK0:" << AK0 << ", "
|
||||
<< "BK0:" << BK0 << ", "
|
||||
<< "MBlock: " << MBlock << ", "
|
||||
<< "NBlock: " << NBlock << "}" << std::endl;
|
||||
// clang-format off
|
||||
}
|
||||
|
||||
index_t M;
|
||||
@@ -831,6 +818,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
@@ -888,7 +879,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / MPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / MPerXdl;
|
||||
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
|
||||
@@ -969,6 +960,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves);
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
@@ -1022,7 +1016,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / NPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / NPerXdl;
|
||||
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
|
||||
@@ -1169,12 +1163,99 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
|
||||
__device__ static bool constexpr IsValidCompilationParameter()
|
||||
{
|
||||
enum struct Arch : bool
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
is_gfx950_build = true,
|
||||
#else
|
||||
is_gfx950_build = false,
|
||||
#endif
|
||||
};
|
||||
|
||||
// skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
|
||||
if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
|
||||
(AK1Number < 32 && BK1Number < 32) ||
|
||||
(AK1Number >= 32 && APackedSize == 2) ||
|
||||
(BK1Number >= 32 && BPackedSize == 2))
|
||||
{
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check tile size
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(MPerXdl != 16 || NPerXdl != 16)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
// Check atomic caps
|
||||
#if defined(__gfx11__)
|
||||
constexpr bool SupportMemOp = CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set;
|
||||
#else
|
||||
constexpr bool SupportMemOp = sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation ==
|
||||
InMemoryDataOperationEnum::Set);
|
||||
#endif
|
||||
if constexpr(SupportMemOp == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check tile size
|
||||
if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
if constexpr(MWaves > 0 && NWaves > 0)
|
||||
{
|
||||
constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
|
||||
if constexpr(WaveSize == get_warp_size())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
__host__ static constexpr bool CheckValidity(const Argument& karg)
|
||||
{
|
||||
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
"Invalid tuning param!");
|
||||
if constexpr((MPerXdl * MXdlPerWave) == 0 || (NXdlPerWave * NPerXdl) == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr((MPerBlock % (MPerXdl * MXdlPerWave) != 0) ||
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl) != 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(BlockwiseGemmPipe::WaveSize != get_warp_size())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/amd_xdlops.hpp"
|
||||
#include "ck/utility/amd_wmma.hpp"
|
||||
|
||||
namespace ck {
|
||||
/**
|
||||
@@ -76,7 +77,21 @@ enum struct MfmaInstr
|
||||
mfma_f32_32x32x64f8f6f4,
|
||||
mfma_f32_16x16x128f8f6f4,
|
||||
mfma_scale_f32_32x32x64f8f6f4,
|
||||
mfma_scale_f32_16x16x128f8f6f4
|
||||
mfma_scale_f32_16x16x128f8f6f4,
|
||||
// gfx11
|
||||
wmma_f32_16x16x16_f16,
|
||||
wmma_f32_16x16x16_bf16,
|
||||
wmma_i32_16x16x16_iu8,
|
||||
wmma_unsupport_16x16_gfx11,
|
||||
// gfx12
|
||||
wmma_f32_16x16x16_f16_gfx12,
|
||||
wmma_f32_16x16x16_bf16_gfx12,
|
||||
wmma_i32_16x16x16_iu8_gfx12,
|
||||
wmma_f32_16x16x16_f8f8_gfx12,
|
||||
wmma_f32_16x16x16_f8bf8_gfx12,
|
||||
wmma_f32_16x16x16_bf8f8_gfx12,
|
||||
wmma_f32_16x16x16_bf8bf8_gfx12,
|
||||
wmma_unsupport_16x16_gfx12,
|
||||
};
|
||||
|
||||
template <MfmaInstr instr>
|
||||
@@ -932,6 +947,175 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
}
|
||||
};
|
||||
|
||||
// gfx11
|
||||
struct mfma_type_gfx11_base
|
||||
{
|
||||
static constexpr index_t group_size = 8;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 8;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 32;
|
||||
static constexpr index_t num_input_blks = 1;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 16;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f16> : public mfma_type_gfx11_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf16> : public mfma_type_gfx11_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_i32_16x16x16_iu8> : public mfma_type_gfx11_base
|
||||
{
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC,
|
||||
bool neg_a = true,
|
||||
bool neg_b = true,
|
||||
bool clamp = false>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_i32_16x16x16_iu8_w32<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx11> : public mfma_type_gfx11_base
|
||||
{
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA&, const FloatB&, FloatC&) const
|
||||
{
|
||||
// empty for all unsupported types.
|
||||
}
|
||||
};
|
||||
|
||||
// gfx12
|
||||
struct mfma_type_gfx12_base
|
||||
{
|
||||
static constexpr index_t group_size = 8;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 8;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 32;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f16_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf16_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_i32_16x16x16_iu8_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC,
|
||||
bool neg_a = true,
|
||||
bool neg_b = true,
|
||||
bool clamp = false>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
|
||||
a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA&, const FloatB&, FloatC&) const
|
||||
{
|
||||
// empty for all unsupported types.
|
||||
}
|
||||
};
|
||||
|
||||
template <typename base_type,
|
||||
index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
@@ -951,7 +1135,13 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<double, 16, 16>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f64_16x16x4f64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -993,7 +1183,13 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 16, 16>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x4xf32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -1026,7 +1222,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f16;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x32f16;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x16f16;
|
||||
@@ -1036,7 +1236,13 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f16;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x16f16;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -1082,7 +1288,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf16;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x32bf16;
|
||||
#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
|
||||
@@ -1094,7 +1304,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
|
||||
{
|
||||
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf16;
|
||||
#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x8bf16;
|
||||
@@ -1126,7 +1340,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 16, 16, int8_t, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_i32_16x16x16_iu8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_i32_16x16x16_iu8;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_i32_16x16x64i8;
|
||||
#elif defined(__gfx942__)
|
||||
return MfmaInstr::mfma_i32_16x16x32i8;
|
||||
@@ -1138,7 +1356,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 16, 16, int8_t, true>()
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_i32_16x16x16_iu8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_i32_16x16x16_iu8;
|
||||
#elif defined(__gfx942__) || defined(__gfx950__)
|
||||
return MfmaInstr::mfma_i32_16x16x32i8;
|
||||
#else
|
||||
return MfmaInstr::mfma_i32_16x16x16i8;
|
||||
@@ -1186,13 +1408,23 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
@@ -1263,13 +1495,23 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
|
||||
@@ -1295,13 +1537,23 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8bf8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8bf8;
|
||||
@@ -1327,13 +1579,23 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8f8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8f8;
|
||||
@@ -1355,10 +1617,18 @@ struct MfmaSelector
|
||||
|
||||
static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
|
||||
"n_per_blk != num_threads_per_blk");
|
||||
|
||||
#if defined(__gfx11__)
|
||||
if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
|
||||
{
|
||||
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks * 2 ==
|
||||
selected_mfma.m_per_blk,
|
||||
"m_per_blk != num_input_blks * num_regs_per_blk");
|
||||
}
|
||||
#else
|
||||
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
|
||||
selected_mfma.m_per_blk,
|
||||
"m_per_blk != num_input_blks * num_regs_per_blk");
|
||||
#endif
|
||||
|
||||
static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
|
||||
selected_mfma.num_output_blks == 1,
|
||||
@@ -1424,8 +1694,9 @@ struct XdlopsGemm
|
||||
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
|
||||
MPerXdlops == 64,
|
||||
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
|
||||
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
|
||||
#endif
|
||||
}
|
||||
|
||||
// XDL output supporting C = A * B
|
||||
@@ -1434,10 +1705,11 @@ struct XdlopsGemm
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_m0_n0_m1_n1_m2_n2,
|
||||
@@ -1446,7 +1718,7 @@ struct XdlopsGemm
|
||||
make_pass_through_transform(M1),
|
||||
make_pass_through_transform(N1),
|
||||
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
|
||||
Number<mfma_instr.num_input_blks>{},
|
||||
Number<num_blks>{},
|
||||
Number<mfma_instr.group_size>{})),
|
||||
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
@@ -1469,12 +1741,13 @@ struct XdlopsGemm
|
||||
__host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
|
||||
const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
|
||||
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
|
||||
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
|
||||
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_m0_n0_m1_n1_m2_n2,
|
||||
@@ -1485,7 +1758,7 @@ struct XdlopsGemm
|
||||
make_pass_through_transform(M2),
|
||||
make_pass_through_transform(N2),
|
||||
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
|
||||
Number<mfma_instr.num_input_blks>{},
|
||||
Number<num_blks>{},
|
||||
Number<mfma_instr.group_size>{})),
|
||||
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
@@ -1512,10 +1785,11 @@ struct XdlopsGemm
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_m0_n0_m1_n1_m2_n2,
|
||||
@@ -1525,7 +1799,7 @@ struct XdlopsGemm
|
||||
make_pass_through_transform(N1),
|
||||
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{}),
|
||||
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
|
||||
Number<mfma_instr.num_input_blks>{},
|
||||
Number<num_blks>{},
|
||||
Number<mfma_instr.group_size>{}))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -1545,11 +1819,12 @@ struct XdlopsGemm
|
||||
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
|
||||
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
|
||||
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_g_m0_n0_m1_n1_m2_n2,
|
||||
@@ -1558,9 +1833,8 @@ struct XdlopsGemm
|
||||
make_pass_through_transform(N0),
|
||||
make_pass_through_transform(M1),
|
||||
make_pass_through_transform(N1),
|
||||
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
|
||||
mfma_instr.num_input_blks,
|
||||
mfma_instr.group_size)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
mfma_instr.num_groups_per_blk, num_blks, mfma_instr.group_size)),
|
||||
make_pass_through_transform(mfma_instr.num_threads_per_blk)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -1642,8 +1916,32 @@ struct XdlopsGemm
|
||||
|
||||
__device__ static auto GetBlkIdx()
|
||||
{
|
||||
const auto laneId = GetLaneId();
|
||||
const auto laneId = GetLaneId();
|
||||
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
|
||||
|
||||
constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(1, num_blks, mfma_instr.num_threads_per_blk))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto blk_idx =
|
||||
threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
|
||||
|
||||
const auto blk_id = blk_idx[I1];
|
||||
const auto blk_td = blk_idx[I2];
|
||||
|
||||
return make_tuple(blk_id, blk_td);
|
||||
}
|
||||
|
||||
template <bool SwizzleA>
|
||||
__device__ static auto GetGfx11InputBlkIdx()
|
||||
{
|
||||
const auto laneId = GetLaneId() % mfma_instr.num_threads_per_blk;
|
||||
if constexpr(SwizzleA)
|
||||
{
|
||||
laneId = ((laneId & 1) << 3) | (laneId >> 1);
|
||||
}
|
||||
constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
|
||||
@@ -1661,8 +1959,12 @@ struct XdlopsGemm
|
||||
|
||||
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const auto laneId = GetLaneId();
|
||||
const auto laneId = GetLaneId();
|
||||
#if defined(__gfx11__)
|
||||
const auto blk_idx = GetGfx11InputBlkIdx<true>();
|
||||
#else
|
||||
const auto blk_idx = GetBlkIdx();
|
||||
#endif
|
||||
|
||||
const auto blk_id = blk_idx[I0];
|
||||
const auto blk_td = blk_idx[I1];
|
||||
@@ -1679,8 +1981,12 @@ struct XdlopsGemm
|
||||
|
||||
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const auto laneId = GetLaneId();
|
||||
const auto laneId = GetLaneId();
|
||||
#if defined(__gfx11__)
|
||||
const auto blk_idx = GetGfx11InputBlkIdx<false>();
|
||||
#else
|
||||
const auto blk_idx = GetBlkIdx();
|
||||
#endif
|
||||
|
||||
const auto blk_id = blk_idx[I0];
|
||||
const auto blk_td = blk_idx[I1];
|
||||
|
||||
Reference in New Issue
Block a user