diff --git a/CMakeLists.txt b/CMakeLists.txt index ba57ead09a..4e12462a41 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -202,7 +202,7 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") set(CK_USE_XDL "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") - message("Enabling FP8 gemms on native architectures") + message("Enabling XDL FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() @@ -211,6 +211,11 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1 add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") + message("Enabling WMMA FP8 gemms on native architectures") + add_definitions(-DCK_USE_WMMA_FP8) + set(CK_USE_WMMA_FP8 "ON") +endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 83b76382bc..e38f166c1a 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -125,7 +125,7 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx9__) // for GPU code +#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 994e60025d..306a6c2ff1 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2025 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -115,6 +115,10 @@ #cmakedefine CK_USE_WMMA @CK_USE_WMMA@ #endif +#ifndef CK_USE_WMMA_FP8 +#cmakedefine CK_USE_WMMA_FP8 @CK_USE_WMMA_FP8@ +#endif + #ifndef CK_USE_GFX94 #cmakedefine CK_USE_GFX94 @CK_USE_GFX94@ #endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp new file mode 100644 index 0000000000..2fdabc6bc7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp" + +namespace ck { + +template +constexpr auto BlockGemmPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmWmmaops_pipeline_v3{}; + } + else + { + static_assert(false, "BlockGemmPipeline configuration is not available"); + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp new file mode 100644 index 0000000000..31c4729760 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +template +struct BlockwiseGemmWmmaops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 32; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerWmma); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerWmma); + + static constexpr index_t A_LDS_Read_Width = ALDSReadWidth; + static constexpr index_t B_LDS_Read_Width = BLDSReadWidth; + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); + + static constexpr index_t C_WMMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerWmma * NPerWmma * KPerWmma); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerWmma: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + KPerWmma); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C WMMA inst: %d\n" + "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " + "%d, %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_WMMA_Inst_Num, + A_LDS_Read_Width, + B_LDS_Read_Width, + ALDSWriteWidth, + BLDSWriteWidth, + ABufferLoadWidth, + BBufferLoadWidth); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp new file mode 100644 index 0000000000..a63d32802e --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmWmmaops_pipeline_base +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I5 = Number<5>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = 32; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + +#if defined(__gfx12__) + static constexpr index_t A_KRow = 2; + static constexpr index_t B_KRow = 2; +#else + static constexpr index_t A_KRow = 1; + static constexpr index_t B_KRow = 1; +#endif + + static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5); + + static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!"); + static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!"); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t KRepeat = KPerBlock / KPack; + + static constexpr auto WmmaK = Number{}; + + using HotLoopInstList = + ck::BlockwiseGemmWmmaops_pipeline_hotloop_inst; + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto wmma_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + +#if defined(__gfx12__) + const auto wmma_krow = wmma_gemm.GetSubGroupId(); +#else + const auto wmma_krow = 0; +#endif + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto wmma_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + +#if defined(__gfx12__) + const auto wmma_krow = wmma_gemm.GetSubGroupId(); +#else + const auto wmma_krow = 0; +#endif + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperwmma_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWmma))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperwmma_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWmma))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperwmma_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperwmma_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + + /** + * @brief Constructor for BlockwiseGemmWmmaops_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding WMMA and + * repeat dimensions. + */ + __host__ __device__ + BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AWmmaTileDesc::IsKnownAtCompileTime() && + BWmmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 && + NPerBlock % (NPerWmma * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1; + + protected: + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + I1, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + I1, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + // C[M, N, NumRegWmma] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + using AThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + using BThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp new file mode 100644 index 0000000000..2fb95f0f8d --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -0,0 +1,466 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmWmmaops_pipeline_v3 +{ +}; + +template +struct BlockwiseGemmWmmaops_pipeline_v3 + : BlockwiseGemmWmmaops_pipeline_base +{ + using Base = BlockwiseGemmWmmaops_pipeline_base; + using Base::I0; + + using Base::A_K1; + using Base::A_KRow; + using Base::B_K1; + using Base::B_KRow; + using Base::KRepeat; + using Base::WmmaK; + + using Base::wmma_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base:: + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + using Base::GetCThreadBuffer; + using Base:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + + using Base::a_block_desc_k0_m0_m1_m2_k1; + using Base::b_block_desc_k0_n0_n1_n2_k1; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // TODO: Calculation of the number of instructions may require changes for WMMA + /* + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num; + + constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_wmma_rate = + (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_wmma_rate = + (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_wmma = + (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate; + constexpr auto num_dsread_b_wmma = + (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma); + constexpr auto num_wmma_per_issue = + num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA + }); + + // stage 2 + static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >= + ds_read_a_wmma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_wmma - 1) * + ds_read_a_wmma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + + static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >= + ds_read_b_wmma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_wmma - 1) * + ds_read_b_wmma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + */ + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..1ef8a9b8ad --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp @@ -0,0 +1,542 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// @brief \"Universal\" GEMM operation with SplitK support. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations applied to the A, B, and C tensors, respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam CDataType C tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +template +struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // TODO: Implement + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"< +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +/// @brief \"Universal\" GEMM kernel with SplitK support. +/// +/// @par Overview +/// This GEMM kernel is carrying out following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations that could be applied on each tensor respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam CDataType C tensor data type. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1Value The vector load size from global memory for A tensor. +/// @tparam BK1Value The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam AThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate +/// (return back to the window origin) after all thread finish data copy. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate +/// (return back to the window origin) after all thread finish data copy. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With universal GEMM +/// there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +template +struct GridwiseGemm_wmma_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + WmmaSelector::selected_wmma + .k_per_wmma); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) + { + // K0_N_K1 -> K0_MNRepeat_MNWaves_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + static_assert(!PermuteA, "PermuteA is not supported"); + + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + + return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // TODO: Investigate why this path is not used in the original + // gridwise_gemm_xdl_cshuffle_v3.hpp +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerWmma; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerWmma; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + ComputeTypeB, + AccDataType, + decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), + decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + KPack>())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NPerWmma * NRepeat)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm_pipeline + .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm_pipeline + .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 1, 2, 6>{}, + Sequence<>{}, + Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor + .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor + .CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 1abae56be4..429df2413f 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,6 +22,10 @@ enum struct WmmaInstr wmma_f32_16x16x16_f16_gfx12, wmma_f32_16x16x16_bf16_gfx12, wmma_i32_16x16x16_iu8_gfx12, + wmma_f32_16x16x16_f8f8_gfx12, + wmma_f32_16x16x16_f8bf8_gfx12, + wmma_f32_16x16x16_bf8f8_gfx12, + wmma_f32_16x16x16_bf8bf8_gfx12, }; /* @@ -400,6 +404,146 @@ struct wmma_type +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_f8f8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + template + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12; + } + + template <> + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12; + } + + template <> + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12; + } + + template <> + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12; + } + // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround static constexpr auto selected_wmma = wmma_type(), Number<32>{}>{}; @@ -612,14 +781,17 @@ struct WmmaGemm (is_same::value && is_same::value && is_same::value) || (is_same::value && is_same::value && - is_same::value) + is_same::value) || + ((is_same::value || is_same::value) && + (is_same::value || is_same::value) && + is_same::value) || #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || (is_same::value && is_same::value && - is_same::value) + (is_same::value && is_same::value && + is_same::value) || #endif - , + false, "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " - "(int8, int32) or (int4, int32)!"); + "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!"); static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) { if constexpr(!TransposeC) { diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 317f324e6d..62e3220b5a 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -581,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ tmp.template AsType()[i]); }); } -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__) else if constexpr(is_same::value) { vector_type tmp{src_thread_data}; diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index aa519fb2be..e14c0d62a8 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP @@ -341,5 +341,101 @@ struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> } }; +// src: f8, f8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: f8, bf8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf8, f8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf8, bf8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + } // namespace ck #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp index 4218c51ca3..79212e16dd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,521 +7,22 @@ #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef CK_USE_WMMA +#include "gemm_universal_wmma.inc" +#endif +#ifdef CK_USE_XDL +#include "gemm_universal_xdl.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -#ifdef CK_ENABLE_FP16 -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -#endif -#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)) -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances( - std::vector>>& - instances); -#endif -#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -#endif template > op_ptrs; +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances(op_ptrs); + } + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + op_ptrs); + } + } +#endif +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs); + } + } +#endif +#endif // CK_USE_WMMA + +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) @@ -822,6 +399,7 @@ struct DeviceOperationInstanceFactory< } #endif +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -831,7 +409,8 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); } } - +#endif +#ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -842,6 +421,8 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#endif // CK_USE_XDL return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_wmma.inc new file mode 100644 index 0000000000..1396437326 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_wmma.inc @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + std::vector>>& + instances); +#endif +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_xdl.inc new file mode 100644 index 0000000000..f0de713834 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_xdl.inc @@ -0,0 +1,521 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +#endif +#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances( + std::vector>>& + instances); +#endif +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 70e54962ed..fe35d9ca76 100755 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -81,21 +81,29 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - # Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 + # Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) + foreach(source IN LISTS ARGN) + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply" AND source MATCHES "_f8_") + message("removing gemm_multiply_multiply_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_") + message("removing gemm_universal_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + endif() + # Do not build WMMA gemm_universal_f8 for any targets except gfx12+ foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply" AND source MATCHES "_f8_") - message("removing gemm_multiply_multiply_f8 instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() + if(NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "gemm_wmma_universal" AND source MATCHES "_f8_") + message("removing gemm_universal_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() endforeach() - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_") - message("removing gemm_universal_f8 instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() - endforeach() - endif() + message("remaining instances: ${ARGN}") #only continue if there are some source files left on the list if(ARGN) set(INST_OBJ) @@ -124,7 +132,10 @@ function(add_instance_library INSTANCE_NAME) endif() if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) - endif() + endif() + endif() + if(source MATCHES "gemm_wmma_universal" AND source MATCHES "f8") + list(FILTER INST_TARGETS INCLUDE REGEX "gfx12") endif() set(offload_targets) foreach(target IN LISTS INST_TARGETS) @@ -455,4 +466,3 @@ set(DEV_OPS_INC_DIRS ${PROJECT_SOURCE_DIR}/library/include/ck/ ) rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) - diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt index ade65eacf3..18eeefa522 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt @@ -1,7 +1,17 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_UNIVERSAL_INSTANCES) -list(APPEND GEMM_UNIVERSAL_INSTANCES +list(APPEND GEMM_UNIVERSAL_INSTANCES + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp + + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -18,7 +28,7 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp - + device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -57,6 +67,16 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instance.cpp ) +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") @@ -80,6 +100,9 @@ set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") list(APPEND GEMM_UNIVERSAL_INSTANCES + device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -134,25 +157,28 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp ) - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") add_instance_library(device_gemm_universal_instance ${GEMM_UNIVERSAL_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp new file mode 100644 index 0000000000..5d3bb3f7b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..c9a730de68 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp new file mode 100644 index 0000000000..6c3a641f9f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..cd88edec59 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000..b700e78d3d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + // Configurations used during development, mainly for testing + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..9951c02251 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..7b4cd64d33 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..3a607c4178 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp new file mode 100644 index 0000000000..3751dc5a11 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..3971802415 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp new file mode 100644 index 0000000000..222b49eb7d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..36901b4f38 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 0000000000..6960375ed6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + // Configurations used during development, mainly for testing + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..bbc8b92217 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..7f71cf6f59 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..331ca8b2ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000..2fca3551b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| TypeA| TypeB| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..5087a9d719 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + if(ck::is_gfx11_supported()) + return; + + add_device_operation_instances( + instances, device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..244eb69190 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| TypeA| TypeB| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..89df765517 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + if(ck::is_gfx11_supported()) + return; + + add_device_operation_instances( + instances, device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index 2054ffbbb3..f7b1d5f1f8 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -9,7 +9,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_universal.hpp" diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 9cb70e4670..17c8c277eb 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -58,7 +58,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) list(APPEND PROFILER_SOURCES profile_gemm_b_scale.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp) list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp) @@ -76,6 +75,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) endif() + list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) @@ -144,7 +144,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) endif() target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_b_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance) @@ -170,6 +169,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) diff --git a/profiler/src/profile_gemm_universal.cpp b/profiler/src/profile_gemm_universal.cpp index a22d983da5..7f2393a7e6 100644 --- a/profiler/src/profile_gemm_universal.cpp +++ b/profiler/src/profile_gemm_universal.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -103,8 +103,10 @@ int profile_gemm_universal(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8) using F8 = ck::f8_t; +#endif +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) using I4 = ck::pk_i4_t; #endif @@ -201,7 +203,7 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{}); } -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8) else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{}); @@ -210,6 +212,8 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{}); } +#endif +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN) { return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt index cf5c68e220..0a68622ebe 100755 --- a/test/gemm_universal/CMakeLists.txt +++ b/test/gemm_universal/CMakeLists.txt @@ -1,15 +1,29 @@ -add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp) +add_gtest_executable(test_gemm_universal_wmma_fp16 test_gemm_universal_wmma_fp16.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance) - endif() - -add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance) + target_link_libraries(test_gemm_universal_wmma_fp16 PRIVATE utility device_gemm_universal_instance) endif() -add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp) +add_gtest_executable(test_gemm_universal_wmma_bf16 test_gemm_universal_wmma_bf16.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance) + target_link_libraries(test_gemm_universal_wmma_bf16 PRIVATE utility device_gemm_universal_instance) endif() +add_gtest_executable(test_gemm_universal_wmma_fp8 test_gemm_universal_wmma_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_wmma_fp8 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_xdl_fp16 test_gemm_universal_xdl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_xdl_fp16 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_xdl_fp8 test_gemm_universal_xdl_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_xdl_fp8 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_xdl_bf16 test_gemm_universal_xdl_bf16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_xdl_bf16 PRIVATE utility device_gemm_universal_instance) +endif() diff --git a/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp b/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp new file mode 100644 index 0000000000..22376a8599 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_BF16_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_BF16_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_BF16_KM_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_BF16_KM_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +using KernelTypes_KM_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +using KernelTypes_KM_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK); + +#include "test_gemm_universal_ut_cases_bf16.inc" diff --git a/test/gemm_universal/test_gemm_universal_wmma_fp16.cpp b/test/gemm_universal/test_gemm_universal_wmma_fp16.cpp new file mode 100644 index 0000000000..1adee41ed2 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_wmma_fp16.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F16 = ck::half_t; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP16_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F16, F16, F16, F16> + >; + +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F16, F16, F16, F16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_ut_cases_fp16.inc" diff --git a/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp b/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp new file mode 100644 index 0000000000..3579424496 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +#if CK_USE_WMMA_FP8 + +using F8 = ck::f8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP8_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP8_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F8, F8, F8, BF16> + >; + +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F8, F8, F8, BF16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_ut_cases_fp8.inc" + +#endif // CK_USE_WMMA_FP8