diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp
index 5439bbe1f0..2bc5a4414e 100644
--- a/include/ck/host_utility/device_prop.hpp
+++ b/include/ck/host_utility/device_prop.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -52,10 +52,27 @@ inline std::string get_device_name()
     }
 }
 
+inline bool is_gfx12_supported()
+{
+    return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
+}
+
+inline bool is_gfx11_supported()
+{
+    return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
+           ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" ||
+           ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" ||
+           ck::get_device_name() == "gfx1152";
+}
+
 inline bool is_xdl_supported()
 {
     return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-           ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
+           ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"
+#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
+           || is_gfx12_supported() || is_gfx11_supported()
+#endif
+        ;
 }
 
 inline bool is_lds_direct_load_supported()
@@ -67,7 +84,8 @@ inline bool is_lds_direct_load_supported()
 
 inline bool is_bf16_atomic_supported()
 {
-    return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
+    return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
+           is_gfx12_supported();
 }
 
 inline bool is_gfx101_supported()
@@ -83,18 +101,5 @@ inline bool is_gfx103_supported()
            ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
 }
 
-inline bool is_gfx11_supported()
-{
-    return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" ||
-           ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" ||
-           ck::get_device_name() == "gfx1152";
-}
-
-inline bool is_gfx12_supported()
-{
-    return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
-}
-
 } // namespace ck
 #endif
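The helpers move above `is_xdl_supported()` because they are now referenced there; they remain plain host-side string checks, so callers can gate instance selection before any wave-size-dependent template is touched. A minimal sketch of such gating (the `run_wave*_instances()` functions are hypothetical placeholders, not part of this PR):

```cpp
// Hedged sketch: host-side dispatch on the helpers above. Only the ck::is_*
// functions come from device_prop.hpp; the run_* functions are stand-ins for
// whatever instance lists the caller owns.
#include "ck/host_utility/device_prop.hpp"

bool run_wave32_instances() { return true; } // hypothetical: WMMA-backed instances
bool run_wave64_instances() { return true; } // hypothetical: MFMA/XDL instances

bool dispatch_by_device()
{
    // RDNA3/RDNA4 (gfx11/gfx12) take the wave32 WMMA path; CDNA (gfx9) keeps
    // the wave64 MFMA/XDL path.
    if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
        return run_wave32_instances();
    if(ck::is_xdl_supported())
        return run_wave64_instances();
    return false; // no supported instance for this device
}
```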
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
index cd13dbb836..acd1d2ae49 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
@@ -41,7 +41,8 @@ struct BlockwiseGemmXdlops_pipeline_base
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 
-    // Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
-    static constexpr index_t WaveSize = 64;
+    static constexpr index_t MWaves   = MPerBlock / (MRepeat * MPerXDL);
+    static constexpr index_t NWaves   = NPerBlock / (NRepeat * NPerXDL);
+    static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
 
     static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
     static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
@@ -74,9 +76,6 @@ struct BlockwiseGemmXdlops_pipeline_base
         return 1;
     }();
 
-    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
-    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
-
     using HotLoopInstList =
         ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
                                                       MPerBlock,
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
@@ -138,9 +138,10 @@ struct BlockwiseGemmXdlops_pipeline_v2
     static constexpr index_t WgpPerCU =
-        (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
+        (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
     static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
         32768 / WgpPerCU,
         (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -625,13 +626,14 @@ struct BlockwiseGemmXdlops_pipeline_v2
     static constexpr index_t WgpPerCU =
-        (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
+        (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
     static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
         32768 / WgpPerCU,
         (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
index 0c030030fe..119f8a3306 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -141,9 +141,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale
     static constexpr index_t WgpPerCU =
-        (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
+        (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
     static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
         32768 / WgpPerCU,
         (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
index 69002d7962..80c65515e8 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -139,9 +139,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale
     static constexpr index_t WgpPerCU =
-        (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
+        (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
     static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
         32768 / WgpPerCU,
         (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -626,13 +627,14 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale
     static constexpr index_t WgpPerCU =
-        (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
+        (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
     static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
         32768 / WgpPerCU,
         (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
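Every pipeline variant now derives `WaveSize` from the tile shape instead of hardcoding 64, so wave32 targets get the correct prefetch-stage math. The arithmetic is easy to sanity-check in isolation; a standalone constexpr sketch (the tile numbers below are made-up examples, not instances from this PR):

```cpp
#include <cstdio>

// Same arithmetic as BlockwiseGemmXdlops_pipeline_base: the wave counts are
// implied by how many MxN sub-tiles one wave covers, and WaveSize is whatever
// remains of BlockSize after distributing waves over the tile.
constexpr int wave_size(int BlockSize, int MPerBlock, int NPerBlock,
                        int MRepeat, int NRepeat, int MPerXDL, int NPerXDL)
{
    const int MWaves = MPerBlock / (MRepeat * MPerXDL);
    const int NWaves = NPerBlock / (NRepeat * NPerXDL);
    return BlockSize / MWaves / NWaves;
}

// wave64 example: 256 threads in 2x2 waves -> 64 lanes per wave.
static_assert(wave_size(256, 128, 128, 2, 2, 32, 32) == 64);
// wave32 (RDNA) example: 256 threads in 4x2 waves -> 32 lanes per wave.
static_assert(wave_size(256, 128, 64, 2, 2, 16, 16) == 32);

int main() { std::puts("wave-size arithmetic holds"); }
```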
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
index b5d6180ab3..7203348418 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -159,6 +159,7 @@ struct BlockwiseGemmXdlops_pipeline_v3
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
@@ -185,7 +185,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2
 {
 
+    template <bool isWave64>
+    static constexpr auto GetNXdlPerWave()
+    {
+        constexpr index_t Waves  = isWave64 ? BlockSize / 64 : BlockSize / 32;
+        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
+        static_assert(MWaves > 0);
+
+        constexpr index_t NWaves = Waves / MWaves;
+        if constexpr(NWaves == 0)
+        {
+            return 0;
+        }
+        else
+        {
+            if constexpr(NPerBlock % (NPerXDL * NWaves) == 0)
+            {
+                return NPerBlock / (NWaves * NPerXDL);
+            }
+            else
+            {
+                return 0;
+            }
+        }
+    }
+
     // GridwiseGemm
-    using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
+    static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
+    static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
+
+    template <index_t NXdlPerWave_>
+    using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3<
         ALayout,
         BLayout,
         CLayout,
@@ -199,7 +227,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2
-        NXdlPerWave>;
+        NXdlPerWave_>;
+
+    using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>;
+    using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
 
-    using Argument = typename GridwiseGemm::Argument;
+    using Argument = typename GridwiseGemm64::Argument;
 
     static constexpr index_t APackedSize = []() {
         if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
@@ -254,12 +284,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        template <typename GridwiseGemm>
+        float RunImp(const typename GridwiseGemm::Argument& arg,
+                     const StreamConfig& stream_config = StreamConfig{})
         {
             if(stream_config.log_level_ > 0)
             {
@@ -285,7 +312,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2
-                ck::utility::RotatingMemWrapper<Argument> rotating_mem(
+                ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
                     arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
                 rotating_mem.Print();
@@ -733,6 +760,31 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2
+        // Dispatch to the instantiation that matches the device's wave size.
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            if(get_warp_size() == 64)
+            {
+                if constexpr(NXdlPerWave64 > 0)
+                {
+                    return RunImp<GridwiseGemm64>(arg, stream_config);
+                }
+            }
+            else
+            {
+                if constexpr(NXdlPerWave32 > 0)
+                {
+                    return RunImp<GridwiseGemm32>(
+                        reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg),
+                        stream_config);
+                }
+            }
+            return 0;
+        }
+
         // polymorphic
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
@@ -754,9 +806,39 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2
-        if(!ck::is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> &&
-           arg.KBatch > 1)
+        if(arg.KBatch > 1)
         {
-            return false;
+            if(is_gfx11_supported())
+            {
+                return false;
+            }
+
+            if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
+            {
+                return false;
+            }
+
+            if(sizeof(CDataType) == 1)
+            {
+                return false;
+            }
+        }
+
+        if(is_gfx11_supported() || is_gfx12_supported())
+        {
+            if(MPerXDL != 16 || NPerXDL != 16)
+            {
+                return false;
+            }
+        }
+
+        if(is_gfx11_supported())
+        {
+            if constexpr(std::is_same_v<ADataType, f8_t> ||
+                         std::is_same_v<BDataType, f8_t>)
+            {
+                return false;
+            }
         }
 
         if((arg.K % AK1 != 0 || arg.K % BK1 != 0) &&
            !(GemmSpec == GemmSpecialization::MKPadding ||
@@ -767,7 +849,29 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2
             return false;
         }
 
-        return GridwiseGemm::CheckValidity(arg);
+        if(get_warp_size() == 64)
+        {
+            if constexpr(NXdlPerWave64 > 0)
+            {
+                return GridwiseGemm64::CheckValidity(arg);
+            }
+            else
+            {
+                return false;
+            }
+        }
+        else
+        {
+            if constexpr(NXdlPerWave32 > 0)
+            {
+                return GridwiseGemm32::CheckValidity(
+                    reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
+            }
+            else
+            {
+                return false;
+            }
+        }
     }
 
     // polymorphic
@@ -849,6 +953,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2
+        index_t PrefetchStages = 0;
+        index_t AMmaKStride   = 0;
+        if(get_warp_size() == 64)
+        {
+            if constexpr(NXdlPerWave64 > 0)
+            {
+                PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
+                AMmaKStride    = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
+            }
+        }
+        else
+        {
+            if constexpr(NXdlPerWave32 > 0)
+            {
+                PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages;
+                AMmaKStride    = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride;
+            }
+        }
+
         // clang-format off
         str << "DeviceGemmXdlUniversal"
             << "<"
@@ -872,9 +995,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2
-            << "BlkGemmPipelinePrefetchStages: " << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
-            << "AMmaKStride: " << GridwiseGemm::BlockwiseGemmPipe::AMmaKStride
+            << "BlkGemmPipelinePrefetchStages: " << PrefetchStages << ", "
+            << "AMmaKStride: " << AMmaKStride
             << ">";
         // clang-format on
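The device op now compiles both a wave64 and a wave32 gridwise instantiation and picks between them at runtime, with `NXdlPerWave64/32 == 0` marking an instantiation that cannot exist for the tile shape. A reduced model of that pattern, stripped of the CK types (`GridwiseGemmModel` and `query_warp_size` are hypothetical stand-ins):

```cpp
#include <cstdio>

// Reduced model of the GridwiseGemm64/GridwiseGemm32 split above: both
// variants exist as types, invalid ones are disabled via a compile-time flag,
// and the host picks one based on the queried warp size.
template <int WaveSize>
struct GridwiseGemmModel
{
    static constexpr bool valid = (WaveSize == 64 || WaveSize == 32);
    static void run() { std::printf("running wave%d pipeline\n", WaveSize); }
};

int query_warp_size() { return 32; } // stand-in for the HIP device attribute query

void dispatch()
{
    if(query_warp_size() == 64)
    {
        if constexpr(GridwiseGemmModel<64>::valid)
            GridwiseGemmModel<64>::run();
    }
    else
    {
        if constexpr(GridwiseGemmModel<32>::valid)
            GridwiseGemmModel<32>::run();
    }
}

int main() { dispatch(); }
```

Note the fall-through: if the matching instantiation is invalid, `Run` in the diff returns 0 rather than launching anything, mirroring `IsSupportedArgument` returning false for the same case.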
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
@@ -38,19 +38,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
         kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
 {
-#if defined(__gfx9__)
-    enum struct Arch : bool
-    {
-#if defined(__gfx950__)
-        is_gfx950_build = true,
-#else
-        is_gfx950_build = false,
-#endif
-    };
-    // skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
-    if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
-                 (GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) ||
-                 (GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) ||
-                 (GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2))
+#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
+    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
     {
         __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -78,23 +66,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 // __attribute__((amdgpu_waves_per_eu(1, 1)))
         kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
 {
-#if defined(__gfx9__)
-    enum struct Arch : bool
-    {
-#if defined(__gfx950__)
-        is_gfx950_build = true,
-#else
-        is_gfx950_build = false,
-#endif
-    };
-    // skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
-    if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
-                 (GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) ||
-                 (GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) ||
-                 (GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2))
-    {
-        // Pass two lds pointer is the key to tell compiler that ds_read/write
-        // operate on different lds chunk at same time without order dependecy
+#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
+    // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
+    // can operate on different LDS chunks at the same time without an ordering dependency.
+    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
+    {
         __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
         __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -696,12 +672,23 @@ struct GridwiseGemm_xdl_cshuffle_v3
         __host__ void Print() const
         {
-            std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
-                      << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
-                      << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
-                      << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
-                      << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
+            // clang-format off
+            std::cout << "problem {"
+                      << "M:" << M << ", "
+                      << "N:" << N << ", "
+                      << "K:" << K << ", "
+                      << "SA:" << StrideA << ", "
+                      << "SB:" << StrideB << ", "
+                      << "SC:" << StrideC << ", "
+                      << "MP:" << MPadded << ", "
+                      << "NP:" << NPadded << ", "
+                      << "KRead:" << KRead << ", "
+                      << "KP:" << KPadded << ", "
+                      << "AK0:" << AK0 << ", "
+                      << "BK0:" << BK0 << ", "
+                      << "MBlock: " << MBlock << ", "
                       << "NBlock: " << NBlock << "}" << std::endl;
+            // clang-format on
         }
 
         index_t M;
@@ -831,6 +818,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
     __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
+        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
+        constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl);
+        constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves);
+
         // A matrix in LDS memory, dst of blockwise copy
         if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
         {
@@ -888,7 +879,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
             constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
             constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
-            constexpr auto KThreadRead      = 64 / MPerXdl;
+            constexpr auto KThreadRead      = WaveSize / MPerXdl;
             constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;
 
             constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
@@ -969,6 +960,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
     __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
     {
+        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
+        constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl);
+        constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves);
         // B matrix in LDS memory, dst of blockwise copy
         if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
         {
@@ -1022,7 +1016,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
             constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
             constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
-            constexpr auto KThreadRead      = 64 / NPerXdl;
+            constexpr auto KThreadRead      = WaveSize / NPerXdl;
             constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;
 
             constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
@@ -1169,12 +1163,99 @@ struct GridwiseGemm_xdl_cshuffle_v3
                c_block_size * sizeof(CShuffleDataType));
     }
 
+    template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
+    __device__ static constexpr bool IsValidCompilationParameter()
+    {
+        enum struct Arch : bool
+        {
+#if defined(__gfx950__)
+            is_gfx950_build = true,
+#else
+            is_gfx950_build = false,
+#endif
+        };
+
+        // Skip building the instances with K1 >= 32 && PackedSize != 2 on pre-gfx950.
+        if constexpr(!(static_cast<bool>(Arch::is_gfx950_build) ||
+                       (AK1Number < 32 && BK1Number < 32) ||
+                       (AK1Number >= 32 && APackedSize == 2) ||
+                       (BK1Number >= 32 && BPackedSize == 2)))
+        {
+            return false;
+        }
+
+        // Check the tile size: the WMMA targets only provide 16x16 instructions here.
+#if defined(__gfx11__) || defined(__gfx12__)
+        if constexpr(MPerXdl != 16 || NPerXdl != 16)
+        {
+            return false;
+        }
+#endif
+
+        // Check atomic capabilities of the output memory operation.
+#if defined(__gfx11__)
+        constexpr bool SupportMemOp = CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set;
+#else
+        constexpr bool SupportMemOp =
+            sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set);
+#endif
+        if constexpr(!SupportMemOp)
+        {
+            return false;
+        }
+
+        // Check that the tile shape reproduces the wave size this target compiles for.
+        if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
+        {
+            constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
+            constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
+            if constexpr(MWaves > 0 && NWaves > 0)
+            {
+                constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
+                return WaveSize == get_warp_size();
+            }
+            else
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+    }
+
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     __host__ static constexpr bool CheckValidity(const Argument& karg)
     {
-        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
-                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
-                      "Invalid tuning param!");
+        if constexpr((MPerXdl * MXdlPerWave) == 0 || (NXdlPerWave * NPerXdl) == 0)
+        {
+            return false;
+        }
+        else
+        {
+            if constexpr((MPerBlock % (MPerXdl * MXdlPerWave) != 0) ||
+                         (NPerBlock % (NXdlPerWave * NPerXdl) != 0))
+            {
+                return false;
+            }
+            else
+            {
+                if(BlockwiseGemmPipe::WaveSize != get_warp_size())
+                {
+                    return false;
+                }
+            }
+        }
 
         if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
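`IsValidCompilationParameter()` turns instances that cannot exist on the current compile target into empty kernels instead of hard errors, which is what lets one instance list serve wave32 and wave64 targets. A reduced compile-time model of the wave-size part of that guard (names simplified; not the CK template):

```cpp
// Reduced model of the wave-size check in IsValidCompilationParameter():
// an instance whose tile arithmetic cannot reproduce the target's wave size
// is rejected at compile time rather than emitted as a broken kernel.
constexpr int target_wave_size()
{
#if defined(__GFX9__)
    return 64; // CDNA
#else
    return 32; // RDNA (also the host fallback in this toy model)
#endif
}

template <int BlockSize, int MWaves, int NWaves>
constexpr bool is_valid_compilation_parameter()
{
    if(MWaves == 0 || NWaves == 0)
        return false;
    return BlockSize / (MWaves * NWaves) == target_wave_size();
}

// 256 threads in 2x2 waves -> 64 lanes: valid only on a wave64 target.
static_assert(is_valid_compilation_parameter<256, 2, 2>() == (target_wave_size() == 64));
// 256 threads in 4x2 waves -> 32 lanes: valid only on a wave32 target.
static_assert(is_valid_compilation_parameter<256, 4, 2>() == (target_wave_size() == 32));
```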
diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
index 64d7f92750..2ce08e7044 100644
--- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
@@ -6,6 +6,7 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/utility/math.hpp"
 #include "ck/utility/amd_xdlops.hpp"
+#include "ck/utility/amd_wmma.hpp"
 
 namespace ck {
 /**
@@ -76,7 +77,21 @@ enum struct MfmaInstr
     mfma_f32_32x32x64f8f6f4,
     mfma_f32_16x16x128f8f6f4,
     mfma_scale_f32_32x32x64f8f6f4,
-    mfma_scale_f32_16x16x128f8f6f4
+    mfma_scale_f32_16x16x128f8f6f4,
+    // gfx11
+    wmma_f32_16x16x16_f16,
+    wmma_f32_16x16x16_bf16,
+    wmma_i32_16x16x16_iu8,
+    wmma_unsupport_16x16_gfx11,
+    // gfx12
+    wmma_f32_16x16x16_f16_gfx12,
+    wmma_f32_16x16x16_bf16_gfx12,
+    wmma_i32_16x16x16_iu8_gfx12,
+    wmma_f32_16x16x16_f8f8_gfx12,
+    wmma_f32_16x16x16_f8bf8_gfx12,
+    wmma_f32_16x16x16_bf8f8_gfx12,
+    wmma_f32_16x16x16_bf8bf8_gfx12,
+    wmma_unsupport_16x16_gfx12,
 };
 
 template <MfmaInstr instr>
@@ -932,6 +947,175 @@ struct mfma_type
     }
 };
 
+// gfx11
+struct mfma_type_gfx11_base
+{
+    static constexpr index_t group_size          = 8;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 8;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 32;
+    static constexpr index_t num_input_blks      = 1;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 16;
+    static constexpr bool is_k_reduction         = true;
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f16> : public mfma_type_gfx11_base
+{
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_wmma_f32_16x16x16_f16_w32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf16> : public mfma_type_gfx11_base
+{
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_wmma_f32_16x16x16_bf16_w32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_i32_16x16x16_iu8> : public mfma_type_gfx11_base
+{
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_wmma_i32_16x16x16_iu8_w32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx11> : public mfma_type_gfx11_base
+{
+    static constexpr index_t k_per_blk = 2;
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA&, const FloatB&, FloatC&) const
+    {
+        // Intentionally empty: fallback for all unsupported types.
+    }
+};
+
+// gfx12
+struct mfma_type_gfx12_base
+{
+    static constexpr index_t group_size          = 8;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 8;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 32;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f16_gfx12> : public mfma_type_gfx12_base
+{
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf16_gfx12> : public mfma_type_gfx12_base
+{
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_i32_16x16x16_iu8_gfx12> : public mfma_type_gfx12_base
+{
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerXdlops, NPerXdlops>::Run(
+            a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12> : public mfma_type_gfx12_base
+{
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12> : public mfma_type_gfx12_base
+{
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12> : public mfma_type_gfx12_base
+{
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12> : public mfma_type_gfx12_base
+{
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx12> : public mfma_type_gfx12_base
+{
+    static constexpr index_t k_per_blk = 2;
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA&, const FloatB&, FloatC&) const
+    {
+        // Intentionally empty: fallback for all unsupported types.
+    }
+};
+
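The trait values above drive all of the index math downstream, so the relations that `MfmaSelector` later `static_assert`s can be checked standalone. A small sketch over local mirrors of the two base structs (values copied from above; these structs are stand-ins, not the CK templates):

```cpp
// Local mirrors of the wave32 WMMA trait blocks above, checking the m_per_blk
// relations that the selector asserts later in this file.
struct gfx11_traits
{
    static constexpr int num_regs_per_blk    = 8;
    static constexpr int num_input_blks      = 1;
    static constexpr int num_threads_per_blk = 16;
    static constexpr int m_per_blk           = 16;
    static constexpr int wave_size           = 32;
};

struct gfx12_traits
{
    static constexpr int num_regs_per_blk    = 8;
    static constexpr int num_input_blks      = 2;
    static constexpr int num_threads_per_blk = 16;
    static constexpr int m_per_blk           = 16;
    static constexpr int wave_size           = 32;
};

// gfx11 carries a single logical input block, so covering the 16 output rows
// needs the extra factor of 2 that the gfx11-only static_assert branch adds.
static_assert(gfx11_traits::num_regs_per_blk * gfx11_traits::num_input_blks * 2 ==
              gfx11_traits::m_per_blk);
// gfx12 splits K across two input blocks, so the generic relation holds as-is.
static_assert(gfx12_traits::num_regs_per_blk * gfx12_traits::num_input_blks ==
              gfx12_traits::m_per_blk);
// Either way, 2 blocks of 16 lanes fill a full wave32.
static_assert(gfx12_traits::num_threads_per_blk * gfx12_traits::num_input_blks ==
              gfx12_traits::wave_size);
```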
@@ -983,7 +1173,13 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#else
         return MfmaInstr::mfma_f64_16x16x4f64;
+#endif
     }
@@ -993,7 +1183,13 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#else
         return MfmaInstr::mfma_f32_16x16x4xf32;
+#endif
     }
@@ -1026,7 +1222,11 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
-#if defined(__gfx950__)
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_f16_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_f32_16x16x16_f16;
+#elif defined(__gfx950__)
         return MfmaInstr::mfma_f32_16x16x32f16;
 #else
         return MfmaInstr::mfma_f32_16x16x16f16;
@@ -1036,7 +1236,13 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_f16_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_f32_16x16x16_f16;
+#else
         return MfmaInstr::mfma_f32_16x16x16f16;
+#endif
     }
@@ -1082,7 +1288,11 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
-#if defined(__gfx950__)
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_bf16_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_f32_16x16x16_bf16;
+#elif defined(__gfx950__)
         return MfmaInstr::mfma_f32_16x16x32bf16;
 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
         return MfmaInstr::mfma_f32_16x16x16bf16_1k;
@@ -1094,7 +1304,11 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
-#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_bf16_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_f32_16x16x16_bf16;
+#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
         return MfmaInstr::mfma_f32_16x16x16bf16_1k;
 #else
         return MfmaInstr::mfma_f32_16x16x8bf16;
@@ -1126,7 +1340,11 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
-#if defined(__gfx950__)
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_i32_16x16x16_iu8_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_i32_16x16x16_iu8;
+#elif defined(__gfx950__)
         return MfmaInstr::mfma_i32_16x16x64i8;
 #elif defined(__gfx942__)
         return MfmaInstr::mfma_i32_16x16x32i8;
@@ -1138,7 +1356,11 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
-#if defined(__gfx942__) || defined(__gfx950__)
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_i32_16x16x16_iu8_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_i32_16x16x16_iu8;
+#elif defined(__gfx942__) || defined(__gfx950__)
         return MfmaInstr::mfma_i32_16x16x32i8;
 #else
         return MfmaInstr::mfma_i32_16x16x16i8;
@@ -1186,13 +1408,23 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#else
         return MfmaInstr::mfma_f32_16x16x32f8f8;
+#endif
     }
 
     template <>
     constexpr auto GetMfma()
     {
-#if defined(__gfx950__)
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#elif defined(__gfx950__)
         return MfmaInstr::mfma_f32_16x16x128f8f6f4;
 #else
         return MfmaInstr::mfma_f32_16x16x32f8f8;
@@ -1263,13 +1495,23 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#else
         return MfmaInstr::mfma_f32_16x16x32bf8bf8;
+#endif
     }
 
     template <>
     constexpr auto GetMfma()
     {
-#if defined(__gfx950__)
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#elif defined(__gfx950__)
         return MfmaInstr::mfma_f32_16x16x128f8f6f4;
 #else
         return MfmaInstr::mfma_f32_16x16x32bf8bf8;
@@ -1295,13 +1537,23 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#else
         return MfmaInstr::mfma_f32_16x16x32f8bf8;
+#endif
     }
 
     template <>
     constexpr auto GetMfma()
     {
-#if defined(__gfx950__)
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#elif defined(__gfx950__)
         return MfmaInstr::mfma_f32_16x16x128f8f6f4;
 #else
         return MfmaInstr::mfma_f32_16x16x32f8bf8;
@@ -1327,13 +1579,23 @@ struct MfmaSelector
     template <>
     constexpr auto GetMfma()
     {
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#else
         return MfmaInstr::mfma_f32_16x16x32bf8f8;
+#endif
     }
 
     template <>
     constexpr auto GetMfma()
     {
-#if defined(__gfx950__)
+#if defined(__gfx12__)
+        return MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
+#elif defined(__gfx11__)
+        return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#elif defined(__gfx950__)
         return MfmaInstr::mfma_f32_16x16x128f8f6f4;
 #else
         return MfmaInstr::mfma_f32_16x16x32bf8f8;
@@ -1355,10 +1617,18 @@ struct MfmaSelector
     static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
                   "n_per_blk != num_threads_per_blk");
 
-
+#if defined(__gfx11__)
+    if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
+    {
+        static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks * 2 ==
+                          selected_mfma.m_per_blk,
+                      "m_per_blk != 2 * num_input_blks * num_regs_per_blk");
+    }
+#else
     static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
                       selected_mfma.m_per_blk,
                   "m_per_blk != num_input_blks * num_regs_per_blk");
+#endif
 
     static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
                       selected_mfma.num_output_blks == 1,
@@ -1424,8 +1694,9 @@ struct XdlopsGemm
         static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 ||
                           MPerXdlops == 32 || MPerXdlops == 64,
                       "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
-
+#if defined(__HIP_DEVICE_COMPILE__)
         static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
+#endif
     }
 
     // XDL output supporting C = A * B
@@ -1434,10 +1705,11 @@ struct XdlopsGemm
     template <typename CDesc_M0_N0_M1_N1_M2_N2>
     __host__ __device__ static constexpr auto
     MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
     {
-        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
-        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
-        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
-        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
+        const auto M0           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
+        const auto N0           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
+        const auto M1           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
+        const auto N1           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
+        constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
 
         return transform_tensor_descriptor(
             c_desc_m0_n0_m1_n1_m2_n2,
@@ -1446,7 +1718,7 @@ struct XdlopsGemm
                 make_pass_through_transform(M1),
                 make_pass_through_transform(N1),
                 make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
-                                                  Number<mfma_instr.num_input_blks>{},
+                                                  Number<num_blks>{},
                                                   Number<mfma_instr.group_size>{})),
                 make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
             make_tuple(Sequence<0>{},
@@ -1469,12 +1741,13 @@ struct XdlopsGemm
     __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
        const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
     {
-        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
-        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
-        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
-        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
-        const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
-        const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
+        const auto M0           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
+        const auto N0           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
+        const auto M1           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
+        const auto N1           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
+        const auto M2           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
+        const auto N2           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
+        constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
 
         return transform_tensor_descriptor(
             c_desc_m0_n0_m1_n1_m2_n2,
@@ -1485,7 +1758,7 @@ struct XdlopsGemm
                 make_pass_through_transform(M2),
                 make_pass_through_transform(N2),
                 make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
-                                                  Number<mfma_instr.num_input_blks>{},
+                                                  Number<num_blks>{},
                                                   Number<mfma_instr.group_size>{})),
                 make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
             make_tuple(Sequence<0>{},
@@ -1512,10 +1785,11 @@ struct XdlopsGemm
     __host__ __device__ static constexpr auto
     MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
     {
-        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
-        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
-        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
-        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
+        const auto M0           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
+        const auto N0           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
+        const auto M1           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
+        const auto N1           = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
+        constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
 
         return transform_tensor_descriptor(
             c_desc_m0_n0_m1_n1_m2_n2,
@@ -1525,7 +1799,7 @@ struct XdlopsGemm
                 make_pass_through_transform(N1),
                 make_pass_through_transform(Number<mfma_instr.num_regs_per_blk>{}),
                 make_unmerge_transform(make_tuple(Number<mfma_instr.num_output_blks>{},
-                                                  Number<mfma_instr.num_input_blks>{},
+                                                  Number<num_blks>{},
                                                   Number<mfma_instr.num_threads_per_blk>{}))),
             make_tuple(Sequence<0>{},
                        Sequence<1>{},
@@ -1545,11 +1819,12 @@ struct XdlopsGemm
     __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
         const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
     {
-        const auto G  = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
-        const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
-        const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
-        const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
-        const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
+        const auto G            = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
+        const auto M0           = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
+        const auto N0           = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
+        const auto M1           = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
+        const auto N1           = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
+        constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
 
         return transform_tensor_descriptor(
             c_desc_g_m0_n0_m1_n1_m2_n2,
@@ -1558,9 +1833,8 @@ struct XdlopsGemm
                 make_pass_through_transform(N0),
                 make_pass_through_transform(M1),
                 make_pass_through_transform(N1),
-                make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
-                                                  mfma_instr.num_input_blks,
-                                                  mfma_instr.group_size)),
+                make_unmerge_transform(make_tuple(
+                    mfma_instr.num_groups_per_blk, num_blks, mfma_instr.group_size)),
                 make_pass_through_transform(mfma_instr.num_threads_per_blk)),
             make_tuple(Sequence<0>{},
                        Sequence<1>{},
@@ -1642,8 +1916,32 @@ struct XdlopsGemm
     __device__ static auto GetBlkIdx()
     {
-        const auto laneId = GetLaneId();
+        const auto laneId       = GetLaneId();
+        constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
 
+        constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(
+                make_merge_transform(make_tuple(1, num_blks, mfma_instr.num_threads_per_blk))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+
+        const auto blk_idx =
+            threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
+
+        const auto blk_id = blk_idx[I1];
+        const auto blk_td = blk_idx[I2];
+
+        return make_tuple(blk_id, blk_td);
+    }
+
+    template <bool SwizzleA>
+    __device__ static auto GetGfx11InputBlkIdx()
+    {
+        auto laneId = GetLaneId() % mfma_instr.num_threads_per_blk;
+        if constexpr(SwizzleA)
+        {
+            laneId = ((laneId & 1) << 3) | (laneId >> 1);
+        }
         constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_merge_transform(
                 make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
@@ -1661,8 +1959,12 @@ struct XdlopsGemm
     __host__ __device__ static auto CalculateAThreadOriginDataIndex()
     {
-        const auto laneId = GetLaneId();
+        const auto laneId  = GetLaneId();
+#if defined(__gfx11__)
+        const auto blk_idx = GetGfx11InputBlkIdx<true>();
+#else
         const auto blk_idx = GetBlkIdx();
+#endif
 
         const auto blk_id = blk_idx[I0];
         const auto blk_td = blk_idx[I1];
@@ -1679,8 +1981,12 @@ struct XdlopsGemm
     __host__ __device__ static auto CalculateBThreadOriginDataIndex()
     {
-        const auto laneId = GetLaneId();
+        const auto laneId  = GetLaneId();
+#if defined(__gfx11__)
+        const auto blk_idx = GetGfx11InputBlkIdx<false>();
+#else
        const auto blk_idx = GetBlkIdx();
+#endif
 
         const auto blk_id = blk_idx[I0];
         const auto blk_td = blk_idx[I1];
diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp
index 861b81b1f6..63466a36f2 100644
--- a/include/ck/utility/blkgemmpipe_scheduler.hpp
+++ b/include/ck/utility/blkgemmpipe_scheduler.hpp
@@ -75,9 +75,9 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
 {
-    static constexpr index_t WaveSize = 64;
     static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
+    static constexpr index_t WaveSize = BlockSize / WaveNumM / WaveNumN;
 
     static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
     static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
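The only non-obvious piece of the new index math is the gfx11 A-operand lane swizzle in `GetGfx11InputBlkIdx`: even lanes map to rows 0..7 and odd lanes to rows 8..15. The permutation can be inspected on the host; a small sketch using the exact bit expression from the diff (the surrounding program is illustrative only):

```cpp
#include <cstdio>

// Host-side illustration of the gfx11 SwizzleA mapping used above:
// ((lane & 1) << 3) | (lane >> 1) interleaves even lanes into rows 0..7
// and odd lanes into rows 8..15 of the 16-row input block.
int swizzle_a(int lane) { return ((lane & 1) << 3) | (lane >> 1); }

int main()
{
    for(int lane = 0; lane < 16; ++lane)
        std::printf("lane %2d -> blk row %2d\n", lane, swizzle_a(lane));
    return 0;
}
```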
diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp
index fd0d1024b2..53e865767b 100644
--- a/include/ck/utility/get_id.hpp
+++ b/include/ck/utility/get_id.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -7,6 +7,38 @@
 namespace ck {
 
+#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
+__device__ constexpr index_t get_warp_size()
+{
+#if defined(__HIP_DEVICE_COMPILE__)
+#if defined(__GFX9__)
+    return 64;
+#else
+    return 32;
+#endif
+#else
+    return 64;
+#endif
+}
+
+inline __host__ index_t get_warp_size()
+{
+#if !(defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC))
+    int device = 0;
+    int result = 0;
+    auto status = hipGetDevice(&device);
+    if(status == hipSuccess)
+    {
+        status = hipDeviceGetAttribute(&result, hipDeviceAttributeWarpSize, device);
+        if(status == hipSuccess)
+        {
+            return result;
+        }
+    }
+#endif
+    return 64;
+}
+#else
 __host__ __device__ constexpr index_t get_warp_size()
 {
 #if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
@@ -15,6 +47,7 @@ __host__ __device__ constexpr index_t get_warp_size()
     return 32;
 #endif
 }
+#endif
 
 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
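With `CK_ENABLE_DYNAMIC_WARP_SIZE`, the host overload replaces the old compile-time constant with a real device query, falling back to 64 when no device is visible. A standalone version of the same query, runnable outside CK (the wrapper function name is an assumption for illustration):

```cpp
// Standalone sketch of the host-side query performed by the new
// get_warp_size() overload; hipDeviceAttributeWarpSize is a real HIP
// device attribute. Falls back to 64 exactly like the code above.
#include <hip/hip_runtime.h>
#include <cstdio>

int query_warp_size()
{
    int device = 0;
    int result = 0;
    if(hipGetDevice(&device) == hipSuccess &&
       hipDeviceGetAttribute(&result, hipDeviceAttributeWarpSize, device) == hipSuccess)
    {
        return result; // 32 on RDNA, 64 on CDNA
    }
    return 64; // conservative default when the query fails
}

int main()
{
    std::printf("warp size: %d\n", query_warp_size());
    return 0;
}
```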